github.com/brandonmartin/migrate/v4@v4.14.2/database/oracle/oracle.go (about) 1 package oracle 2 3 import ( 4 "bufio" 5 "bytes" 6 "context" 7 "database/sql" 8 "fmt" 9 "io" 10 "io/ioutil" 11 nurl "net/url" 12 "strings" 13 14 "github.com/godror/godror" 15 16 _ "github.com/godror/godror" 17 "github.com/golang-migrate/migrate/v4" 18 "github.com/golang-migrate/migrate/v4/database" 19 multierror "github.com/hashicorp/go-multierror" 20 ) 21 22 func init() { 23 db := Oracle{} 24 database.Register("oracle", &db) 25 } 26 27 const ( 28 defaultMigrationsTable = "SCHEMA_MIGRATIONS" 29 defaultStatementSeparator = ";" 30 plsqlDefaultStatementSeparator = "---" 31 plsqlStatementEndToken = "END;" 32 ) 33 34 var ( 35 ErrNilConfig = fmt.Errorf("no config") 36 ErrNoDatabaseName = fmt.Errorf("no database name") 37 ) 38 39 type Config struct { 40 MigrationsTable string 41 DisableMultiStatements bool 42 PLSQLStatementSeparator string 43 44 databaseName string 45 } 46 47 type Oracle struct { 48 // Locking and unlocking need to use the same connection 49 conn *sql.Conn 50 db *sql.DB 51 isLocked bool 52 53 // Open and WithInstance need to guarantee that config is never nil 54 config *Config 55 } 56 57 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 58 if config == nil { 59 return nil, ErrNilConfig 60 } 61 62 if err := instance.Ping(); err != nil { 63 return nil, err 64 } 65 66 query := `SELECT SYS_CONTEXT('USERENV','DB_NAME') FROM DUAL` 67 var dbName string 68 if err := instance.QueryRow(query).Scan(&dbName); err != nil { 69 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 70 } 71 72 if dbName == "" { 73 return nil, ErrNoDatabaseName 74 } 75 76 config.databaseName = dbName 77 78 if config.MigrationsTable == "" { 79 config.MigrationsTable = defaultMigrationsTable 80 } 81 82 if config.PLSQLStatementSeparator == "" { 83 config.PLSQLStatementSeparator = plsqlDefaultStatementSeparator 84 } 85 86 conn, err := instance.Conn(context.Background()) 87 88 if err != nil { 89 return nil, err 90 } 91 92 ora := &Oracle{ 93 conn: conn, 94 db: instance, 95 config: config, 96 } 97 98 if err := ora.ensureVersionTable(); err != nil { 99 return nil, err 100 } 101 102 return ora, nil 103 } 104 105 func (ora *Oracle) Open(url string) (database.Driver, error) { 106 purl, err := nurl.Parse(url) 107 if err != nil { 108 return nil, err 109 } 110 db, err := sql.Open("godror", migrate.FilterCustomQuery(purl).String()) 111 if err != nil { 112 return nil, err 113 } 114 115 migrationsTable := strings.ToUpper(purl.Query().Get("x-migrations-table")) 116 statementSeparator := purl.Query().Get("x-statement-separator") 117 disableMultiStatement := false 118 if purl.Query().Get("x-disable-multi-statements") == "true" { 119 disableMultiStatement = true 120 } 121 122 oraInst, err := WithInstance(db, &Config{ 123 databaseName: purl.Path, 124 MigrationsTable: migrationsTable, 125 DisableMultiStatements: disableMultiStatement, 126 PLSQLStatementSeparator: statementSeparator, 127 }) 128 129 if err != nil { 130 return nil, err 131 } 132 133 return oraInst, nil 134 } 135 136 func (ora *Oracle) Close() error { 137 connErr := ora.conn.Close() 138 dbErr := ora.db.Close() 139 if connErr != nil || dbErr != nil { 140 return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) 141 } 142 return nil 143 } 144 145 func (ora *Oracle) Lock() error { 146 if ora.isLocked { 147 return database.ErrLocked 148 } 149 150 // https://docs.oracle.com/cd/B28359_01/appdev.111/b28419/d_lock.htm#ARPLS021 151 query := ` 152 declare 153 v_lockhandle varchar2(200); 154 v_result number; 155 begin 156 157 dbms_lock.allocate_unique('control_lock', v_lockhandle); 158 159 v_result := dbms_lock.request(v_lockhandle, dbms_lock.x_mode); 160 161 if v_result <> 0 then 162 dbms_output.put_line( 163 case 164 when v_result=1 then 'Timeout' 165 when v_result=2 then 'Deadlock' 166 when v_result=3 then 'Parameter Error' 167 when v_result=4 then 'Already owned' 168 when v_result=5 then 'Illegal Lock Handle' 169 end); 170 end if; 171 172 end; 173 ` 174 if _, err := ora.conn.ExecContext(context.Background(), query); err != nil { 175 return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)} 176 } 177 178 ora.isLocked = true 179 return nil 180 } 181 182 func (ora *Oracle) Unlock() error { 183 if !ora.isLocked { 184 return nil 185 } 186 187 query := ` 188 declare 189 v_lockhandle varchar2(200); 190 v_result number; 191 begin 192 193 dbms_lock.allocate_unique('control_lock', v_lockhandle); 194 195 v_result := dbms_lock.release(v_lockhandle); 196 197 if v_result <> 0 then 198 dbms_output.put_line( 199 case 200 when v_result=1 then 'Timeout' 201 when v_result=2 then 'Deadlock' 202 when v_result=3 then 'Parameter Error' 203 when v_result=4 then 'Already owned' 204 when v_result=5 then 'Illegal Lock Handle' 205 end); 206 end if; 207 208 end; 209 ` 210 if _, err := ora.conn.ExecContext(context.Background(), query); err != nil { 211 return &database.Error{OrigErr: err, Query: []byte(query)} 212 } 213 ora.isLocked = false 214 return nil 215 } 216 217 func (ora *Oracle) Run(migration io.Reader) error { 218 queries, err := parseStatements(migration, ora.config) 219 if err != nil { 220 return err 221 } 222 for _, query := range queries { 223 if _, err := ora.conn.ExecContext(context.Background(), query); err != nil { 224 if oraErr, ok := godror.AsOraErr(err); ok { 225 return database.Error{OrigErr: oraErr, Err: oraErr.Message(), Query: []byte(query)} 226 } 227 return database.Error{OrigErr: err, Err: "migration failed", Query: []byte(query)} 228 } 229 } 230 231 return nil 232 } 233 234 func (ora *Oracle) SetVersion(version int, dirty bool) error { 235 tx, err := ora.conn.BeginTx(context.Background(), &sql.TxOptions{}) 236 if err != nil { 237 return &database.Error{OrigErr: err, Err: "transaction start failed"} 238 } 239 240 query := "TRUNCATE TABLE " + ora.config.MigrationsTable 241 if _, err := tx.Exec(query); err != nil { 242 if errRollback := tx.Rollback(); errRollback != nil { 243 err = multierror.Append(err, errRollback) 244 } 245 return &database.Error{OrigErr: err, Query: []byte(query)} 246 } 247 248 if version >= 0 { 249 query = `INSERT INTO ` + ora.config.MigrationsTable + ` (VERSION, DIRTY) VALUES (:1, :2)` 250 if _, err := tx.Exec(query, version, b2i(dirty)); err != nil { 251 if errRollback := tx.Rollback(); errRollback != nil { 252 err = multierror.Append(err, errRollback) 253 } 254 return &database.Error{OrigErr: err, Query: []byte(query)} 255 } 256 } 257 258 if err := tx.Commit(); err != nil { 259 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 260 } 261 262 return nil 263 } 264 265 func (ora *Oracle) Version() (version int, dirty bool, err error) { 266 query := "SELECT VERSION, DIRTY FROM " + ora.config.MigrationsTable + " WHERE ROWNUM = 1 ORDER BY VERSION desc" 267 err = ora.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) 268 switch { 269 case err == sql.ErrNoRows: 270 return database.NilVersion, false, nil 271 272 case err != nil: 273 if _, ok := godror.AsOraErr(err); ok { 274 return database.NilVersion, false, nil 275 } 276 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 277 278 default: 279 return version, dirty, nil 280 } 281 } 282 283 func (ora *Oracle) Drop() (err error) { 284 // select all tables in current schema 285 query := fmt.Sprintf(`SELECT TABLE_NAME FROM USER_TABLES`) 286 tables, err := ora.conn.QueryContext(context.Background(), query) 287 if err != nil { 288 return &database.Error{OrigErr: err, Query: []byte(query)} 289 } 290 defer func() { 291 if errClose := tables.Close(); errClose != nil { 292 err = multierror.Append(err, errClose) 293 } 294 }() 295 296 // delete one table after another 297 tableNames := make([]string, 0) 298 for tables.Next() { 299 var tableName string 300 if err := tables.Scan(&tableName); err != nil { 301 return err 302 } 303 if len(tableName) > 0 { 304 tableNames = append(tableNames, tableName) 305 } 306 } 307 308 query = ` 309 BEGIN 310 EXECUTE IMMEDIATE 'DROP TABLE %s'; 311 EXCEPTION 312 WHEN OTHERS THEN 313 IF SQLCODE != -942 THEN 314 RAISE; 315 END IF; 316 END; 317 ` 318 if len(tableNames) > 0 { 319 // delete one by one ... 320 for _, t := range tableNames { 321 if _, err := ora.conn.ExecContext(context.Background(), fmt.Sprintf(query, t)); err != nil { 322 return &database.Error{OrigErr: err, Query: []byte(query)} 323 } 324 } 325 } 326 327 return nil 328 } 329 330 // ensureVersionTable checks if versions table exists and, if not, creates it. 331 // Note that this function locks the database, which deviates from the usual 332 // convention of "caller locks" in the Postgres type. 333 func (ora *Oracle) ensureVersionTable() (err error) { 334 if err = ora.Lock(); err != nil { 335 return err 336 } 337 338 defer func() { 339 if e := ora.Unlock(); e != nil { 340 if err == nil { 341 err = e 342 } else { 343 err = multierror.Append(err, e) 344 } 345 } 346 }() 347 348 query := ` 349 declare 350 v_sql LONG; 351 begin 352 353 v_sql:='create table %s 354 ( 355 VERSION NUMBER(20) NOT NULL PRIMARY KEY, 356 DIRTY NUMBER(1) NOT NULL 357 )'; 358 execute immediate v_sql; 359 360 EXCEPTION 361 WHEN OTHERS THEN 362 IF SQLCODE = -955 THEN 363 NULL; -- suppresses ORA-00955 exception 364 ELSE 365 RAISE; 366 END IF; 367 END; 368 ` 369 if _, err = ora.conn.ExecContext(context.Background(), fmt.Sprintf(query, ora.config.MigrationsTable)); err != nil { 370 return &database.Error{OrigErr: err, Query: []byte(query)} 371 } 372 373 return nil 374 } 375 376 func parseStatements(rd io.Reader, c *Config) ([]string, error) { 377 migr, err := ioutil.ReadAll(rd) 378 if err != nil { 379 return nil, err 380 } 381 382 // If multi-statements has been disable explicitly, 383 // i.e, there is no multi-statement enabled(neither normal multi-statements nor multi-PL/SQL-statements), 384 // return the whole migration as a blob. 385 if c.DisableMultiStatements { 386 return []string{string(migr)}, nil 387 } 388 389 // Either normal multi-statements or multi-PL/SQL-statements has been enabled. 390 plsqlEnabled := false 391 if strings.Contains(string(migr), plsqlStatementEndToken) { 392 plsqlEnabled = true 393 } 394 var queries []string 395 var buf bytes.Buffer 396 scanner := bufio.NewScanner(bytes.NewBuffer(migr)) 397 scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024) 398 for scanner.Scan() { 399 line := scanner.Text() 400 if plsqlEnabled && line == c.PLSQLStatementSeparator { 401 query := buf.String() 402 if query != "" { 403 queries = append(queries, query) 404 } 405 buf.Reset() 406 } 407 // ignore comment 408 if strings.HasPrefix(line, "--") { 409 continue 410 } 411 if _, err := buf.WriteString(line + "\n"); err != nil { 412 return nil, err 413 } 414 } 415 if plsqlEnabled { 416 query := buf.String() 417 if query != "" { 418 queries = append(queries, query) 419 } 420 } else { 421 queries = strings.Split(buf.String(), defaultStatementSeparator) 422 } 423 424 results := make([]string, 0) 425 sLen := len(plsqlStatementEndToken) 426 for _, query := range queries { 427 query = strings.TrimSpace(query) 428 query = strings.TrimPrefix(query, "\n") 429 query = strings.TrimSuffix(query, "\n") 430 if len(query) > sLen && strings.ToUpper(query[len(query)-sLen:]) != plsqlStatementEndToken { 431 query = strings.TrimSuffix(query, ";") 432 } 433 if query == "" { 434 continue 435 } 436 results = append(results, query) 437 } 438 return results, nil 439 } 440 441 func b2i(b bool) int { 442 if b { 443 return 1 444 } 445 return 0 446 }