github.com/wfusion/gofusion@v1.1.14/common/infra/drivers/orm/opengauss/migrator.go (about) 1 package opengauss 2 3 import ( 4 "database/sql" 5 "fmt" 6 "regexp" 7 "strings" 8 9 "gorm.io/gorm" 10 "gorm.io/gorm/clause" 11 "gorm.io/gorm/migrator" 12 "gorm.io/gorm/schema" 13 ) 14 15 const indexSql = ` 16 select 17 t.relname as table_name, 18 i.relname as index_name, 19 a.attname as column_name, 20 ix.indisunique as non_unique, 21 ix.indisprimary as primary 22 from 23 pg_class t, 24 pg_class i, 25 pg_index ix, 26 pg_attribute a 27 where 28 t.oid = ix.indrelid 29 and i.oid = ix.indexrelid 30 and a.attrelid = t.oid 31 and a.attnum = ANY(ix.indkey) 32 and t.relkind = 'r' 33 and t.relname = ? 34 ` 35 36 var typeAliasMap = map[string][]string{ 37 "int2": {"smallint"}, 38 "int4": {"integer"}, 39 "int8": {"bigint"}, 40 "smallint": {"int2"}, 41 "integer": {"int4"}, 42 "bigint": {"int8"}, 43 "decimal": {"numeric"}, 44 "numeric": {"decimal"}, 45 "timestamptz": {"timestamp with time zone"}, 46 "timestamp with time zone": {"timestamptz"}, 47 } 48 49 type Migrator struct { 50 migrator.Migrator 51 } 52 53 func (m Migrator) CurrentDatabase() (name string) { 54 m.DB.Raw("SELECT CURRENT_DATABASE()").Scan(&name) 55 return 56 } 57 58 func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) { 59 for _, opt := range opts { 60 str := stmt.Quote(opt.DBName) 61 if opt.Expression != "" { 62 str = opt.Expression 63 } 64 65 if opt.Collate != "" { 66 str += " COLLATE " + opt.Collate 67 } 68 69 if opt.Sort != "" { 70 str += " " + opt.Sort 71 } 72 results = append(results, clause.Expr{SQL: str}) 73 } 74 return 75 } 76 77 func (m Migrator) HasIndex(value interface{}, name string) bool { 78 var count int64 79 m.RunWithValue(value, func(stmt *gorm.Statement) error { 80 if stmt.Schema != nil { 81 if idx := stmt.Schema.LookIndex(name); idx != nil { 82 name = idx.Name 83 } 84 } 85 currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table) 86 return m.DB.Raw( 87 "SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ? AND schemaname = ?", curTable, name, currentSchema, 88 ).Scan(&count).Error 89 }) 90 91 return count > 0 92 } 93 94 func (m Migrator) CreateIndex(value interface{}, name string) error { 95 if !m.HasIndex(value, name) { 96 return nil 97 } 98 99 return m.RunWithValue(value, func(stmt *gorm.Statement) error { 100 if stmt.Schema != nil { 101 if idx := stmt.Schema.LookIndex(name); idx != nil { 102 opts := m.BuildIndexOptions(idx.Fields, stmt) 103 values := []interface{}{clause.Column{Name: idx.Name}, m.CurrentTable(stmt), opts} 104 createIndexSQL := "CREATE " 105 if idx.Class != "" { 106 createIndexSQL += idx.Class + " " 107 } 108 createIndexSQL += "INDEX " 109 110 if strings.TrimSpace(strings.ToUpper(idx.Option)) == "CONCURRENTLY" { 111 createIndexSQL += "CONCURRENTLY " 112 } 113 114 createIndexSQL += "? ON ?" 115 116 if idx.Type != "" { 117 createIndexSQL += " USING " + idx.Type + "(?)" 118 } else { 119 createIndexSQL += " ?" 120 } 121 122 if idx.Where != "" { 123 createIndexSQL += " WHERE " + idx.Where 124 } 125 126 return m.DB.Exec(createIndexSQL, values...).Error 127 } 128 } 129 130 return fmt.Errorf("failed to create index with name %v", name) 131 }) 132 } 133 134 func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error { 135 return m.RunWithValue(value, func(stmt *gorm.Statement) error { 136 return m.DB.Exec( 137 "ALTER INDEX ? RENAME TO ?", 138 clause.Column{Name: oldName}, clause.Column{Name: newName}, 139 ).Error 140 }) 141 } 142 143 func (m Migrator) DropIndex(value interface{}, name string) error { 144 return m.RunWithValue(value, func(stmt *gorm.Statement) error { 145 if stmt.Schema != nil { 146 if idx := stmt.Schema.LookIndex(name); idx != nil { 147 name = idx.Name 148 } 149 } 150 151 return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}).Error 152 }) 153 } 154 155 func (m Migrator) GetTables() (tableList []string, err error) { 156 currentSchema, _ := m.CurrentSchema(m.DB.Statement, "") 157 return tableList, m.DB.Raw("SELECT table_name FROM information_schema.tables WHERE table_schema = ? AND table_type = ?", currentSchema, "BASE TABLE").Scan(&tableList).Error 158 } 159 160 func (m Migrator) CreateTable(values ...interface{}) (err error) { 161 if err = m.Migrator.CreateTable(values...); err != nil { 162 return 163 } 164 for _, value := range m.ReorderModels(values, false) { 165 if err = m.RunWithValue(value, func(stmt *gorm.Statement) error { 166 if stmt.Schema != nil { 167 for _, fieldName := range stmt.Schema.DBNames { 168 field := stmt.Schema.FieldsByDBName[fieldName] 169 if field.Comment != "" { 170 if err := m.DB.Exec( 171 "COMMENT ON COLUMN ?.? IS ?", 172 m.CurrentTable(stmt), clause.Column{Name: field.DBName}, gorm.Expr(m.Migrator.Dialector.Explain("$1", field.Comment)), 173 ).Error; err != nil { 174 return err 175 } 176 } 177 } 178 } 179 return nil 180 }); err != nil { 181 return 182 } 183 } 184 return 185 } 186 187 func (m Migrator) HasTable(value interface{}) bool { 188 var count int64 189 m.RunWithValue(value, func(stmt *gorm.Statement) error { 190 currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table) 191 return m.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?", currentSchema, curTable, "BASE TABLE").Scan(&count).Error 192 }) 193 return count > 0 194 } 195 196 func (m Migrator) DropTable(values ...interface{}) error { 197 values = m.ReorderModels(values, false) 198 tx := m.DB.Session(&gorm.Session{}) 199 for i := len(values) - 1; i >= 0; i-- { 200 if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error { 201 return tx.Exec("DROP TABLE IF EXISTS ? CASCADE", m.CurrentTable(stmt)).Error 202 }); err != nil { 203 return err 204 } 205 } 206 return nil 207 } 208 209 func (m Migrator) AddColumn(value interface{}, field string) error { 210 if err := m.Migrator.AddColumn(value, field); err != nil { 211 return err 212 } 213 m.resetPreparedStmts() 214 215 return m.RunWithValue(value, func(stmt *gorm.Statement) error { 216 if stmt.Schema != nil { 217 if field := stmt.Schema.LookUpField(field); field != nil { 218 if field.Comment != "" { 219 if err := m.DB.Exec( 220 "COMMENT ON COLUMN ?.? IS ?", 221 m.CurrentTable(stmt), clause.Column{Name: field.DBName}, gorm.Expr(m.Migrator.Dialector.Explain("$1", field.Comment)), 222 ).Error; err != nil { 223 return err 224 } 225 } 226 } 227 } 228 return nil 229 }) 230 } 231 232 func (m Migrator) HasColumn(value interface{}, field string) bool { 233 var count int64 234 m.RunWithValue(value, func(stmt *gorm.Statement) error { 235 name := field 236 if stmt.Schema != nil { 237 if field := stmt.Schema.LookUpField(field); field != nil { 238 name = field.DBName 239 } 240 } 241 242 currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table) 243 return m.DB.Raw( 244 "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?", 245 currentSchema, curTable, name, 246 ).Scan(&count).Error 247 }) 248 249 return count > 0 250 } 251 252 func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error { 253 // skip primary field 254 if !field.PrimaryKey { 255 if err := m.Migrator.MigrateColumn(value, field, columnType); err != nil { 256 return err 257 } 258 } 259 260 return m.RunWithValue(value, func(stmt *gorm.Statement) error { 261 var description string 262 currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table) 263 values := []interface{}{currentSchema, curTable, field.DBName, stmt.Table, currentSchema} 264 checkSQL := "SELECT description FROM pg_catalog.pg_description " 265 checkSQL += "WHERE objsubid = (SELECT ordinal_position FROM information_schema.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?) " 266 checkSQL += "AND objoid = (SELECT oid FROM pg_catalog.pg_class WHERE relname = ? AND relnamespace = " 267 checkSQL += "(SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = ?))" 268 m.DB.Raw(checkSQL, values...).Scan(&description) 269 270 comment := strings.Trim(field.Comment, "'") 271 comment = strings.Trim(comment, `"`) 272 if field.Comment != "" && comment != description { 273 if err := m.DB.Exec( 274 "COMMENT ON COLUMN ?.? IS ?", 275 m.CurrentTable(stmt), clause.Column{Name: field.DBName}, gorm.Expr(m.Migrator.Dialector.Explain("$1", field.Comment)), 276 ).Error; err != nil { 277 return err 278 } 279 } 280 return nil 281 }) 282 } 283 284 // AlterColumn alter value's `field` column' type based on schema definition 285 func (m Migrator) AlterColumn(value interface{}, field string) error { 286 err := m.RunWithValue(value, func(stmt *gorm.Statement) error { 287 if stmt.Schema != nil { 288 if field := stmt.Schema.LookUpField(field); field != nil { 289 var ( 290 columnTypes, _ = m.DB.Migrator().ColumnTypes(value) 291 fieldColumnType *migrator.ColumnType 292 ) 293 for _, columnType := range columnTypes { 294 if columnType.Name() == field.DBName { 295 fieldColumnType, _ = columnType.(*migrator.ColumnType) 296 } 297 } 298 299 fileType := clause.Expr{SQL: m.DataTypeOf(field)} 300 // check for typeName and SQL name 301 isSameType := true 302 if fieldColumnType.DatabaseTypeName() != fileType.SQL { 303 isSameType = false 304 // if different, also check for aliases 305 aliases := m.GetTypeAliases(fieldColumnType.DatabaseTypeName()) 306 for _, alias := range aliases { 307 if strings.HasPrefix(fileType.SQL, alias) { 308 isSameType = true 309 break 310 } 311 } 312 } 313 314 // not same, migrate 315 if !isSameType { 316 filedColumnAutoIncrement, _ := fieldColumnType.AutoIncrement() 317 if field.AutoIncrement && filedColumnAutoIncrement { // update 318 serialDatabaseType, _ := getSerialDatabaseType(fileType.SQL) 319 if t, _ := fieldColumnType.ColumnType(); t != serialDatabaseType { 320 if err := m.UpdateSequence(m.DB, stmt, field, serialDatabaseType); err != nil { 321 return err 322 } 323 } 324 } else if field.AutoIncrement && !filedColumnAutoIncrement { // create 325 serialDatabaseType, _ := getSerialDatabaseType(fileType.SQL) 326 if err := m.CreateSequence(m.DB, stmt, field, serialDatabaseType); err != nil { 327 return err 328 } 329 } else if !field.AutoIncrement && filedColumnAutoIncrement { // delete 330 if err := m.DeleteSequence(m.DB, stmt, field, fileType); err != nil { 331 return err 332 } 333 } else { 334 if err := m.modifyColumn(stmt, field, fileType, fieldColumnType); err != nil { 335 return err 336 } 337 } 338 } 339 340 if null, _ := fieldColumnType.Nullable(); null == field.NotNull { 341 if field.NotNull { 342 if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? SET NOT NULL", m.CurrentTable(stmt), clause.Column{Name: field.DBName}).Error; err != nil { 343 return err 344 } 345 } else { 346 if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? DROP NOT NULL", m.CurrentTable(stmt), clause.Column{Name: field.DBName}).Error; err != nil { 347 return err 348 } 349 } 350 } 351 352 if uniq, _ := fieldColumnType.Unique(); !uniq && field.Unique { 353 idxName := clause.Column{Name: m.DB.Config.NamingStrategy.IndexName(stmt.Table, field.DBName)} 354 // Not a unique constraint but a unique index 355 if !m.HasIndex(stmt.Table, idxName.Name) { 356 if err := m.DB.Exec("ALTER TABLE ? ADD CONSTRAINT ? UNIQUE(?)", m.CurrentTable(stmt), idxName, clause.Column{Name: field.DBName}).Error; err != nil { 357 return err 358 } 359 } 360 } 361 362 if v, ok := fieldColumnType.DefaultValue(); (field.DefaultValueInterface == nil && ok) || v != field.DefaultValue { 363 if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") { 364 if field.DefaultValueInterface != nil { 365 defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}} 366 m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValueInterface) 367 if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? SET DEFAULT ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, clause.Expr{SQL: m.Dialector.Explain(defaultStmt.SQL.String(), field.DefaultValueInterface)}).Error; err != nil { 368 return err 369 } 370 } else if field.DefaultValue != "(-)" { 371 if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? SET DEFAULT ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DefaultValue}).Error; err != nil { 372 return err 373 } 374 } else { 375 if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? DROP DEFAULT", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DefaultValue}).Error; err != nil { 376 return err 377 } 378 } 379 } 380 } 381 return nil 382 } 383 } 384 return fmt.Errorf("failed to look up field with name: %s", field) 385 }) 386 387 if err != nil { 388 return err 389 } 390 m.resetPreparedStmts() 391 return nil 392 } 393 394 func (m Migrator) modifyColumn(stmt *gorm.Statement, field *schema.Field, targetType clause.Expr, existingColumn *migrator.ColumnType) error { 395 alterSQL := "ALTER TABLE ? ALTER COLUMN ? TYPE ? USING ?::?" 396 isUncastableDefaultValue := false 397 398 if targetType.SQL == "boolean" { 399 switch existingColumn.DatabaseTypeName() { 400 case "int2", "int8", "numeric": 401 alterSQL = "ALTER TABLE ? ALTER COLUMN ? TYPE ? USING ?::int::?" 402 } 403 isUncastableDefaultValue = true 404 } 405 406 if dv, _ := existingColumn.DefaultValue(); dv != "" && isUncastableDefaultValue { 407 if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? DROP DEFAULT", m.CurrentTable(stmt), clause.Column{Name: field.DBName}).Error; err != nil { 408 return err 409 } 410 } 411 if err := m.DB.Exec(alterSQL, m.CurrentTable(stmt), clause.Column{Name: field.DBName}, targetType, clause.Column{Name: field.DBName}, targetType).Error; err != nil { 412 return err 413 } 414 return nil 415 } 416 417 func (m Migrator) HasConstraint(value interface{}, name string) bool { 418 var count int64 419 m.RunWithValue(value, func(stmt *gorm.Statement) error { 420 constraint, chk, table := m.GuessConstraintAndTable(stmt, name) 421 currentSchema, curTable := m.CurrentSchema(stmt, table) 422 if constraint != nil { 423 name = constraint.Name 424 } else if chk != nil { 425 name = chk.Name 426 } 427 428 return m.DB.Raw( 429 "SELECT count(*) FROM INFORMATION_SCHEMA.table_constraints WHERE table_schema = ? AND table_name = ? AND constraint_name = ?", 430 currentSchema, curTable, name, 431 ).Scan(&count).Error 432 }) 433 434 return count > 0 435 } 436 437 func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType, err error) { 438 columnTypes = make([]gorm.ColumnType, 0) 439 err = m.RunWithValue(value, func(stmt *gorm.Statement) error { 440 var ( 441 currentDatabase = m.DB.Migrator().CurrentDatabase() 442 currentSchema, table = m.CurrentSchema(stmt, stmt.Table) 443 columns, err = m.DB.Raw( 444 "SELECT c.column_name, c.is_nullable = 'YES', c.udt_name, c.character_maximum_length, c.numeric_precision, c.numeric_precision_radix, c.numeric_scale, c.datetime_precision, 8 * typlen, c.column_default, pd.description, c.identity_increment FROM information_schema.columns AS c JOIN pg_type AS pgt ON c.udt_name = pgt.typname LEFT JOIN pg_catalog.pg_description as pd ON pd.objsubid = c.ordinal_position AND pd.objoid = (SELECT oid FROM pg_catalog.pg_class WHERE relname = c.table_name AND relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = c.table_schema)) where table_catalog = ? AND table_schema = ? AND table_name = ?", 445 currentDatabase, currentSchema, table).Rows() 446 ) 447 448 if err != nil { 449 return err 450 } 451 452 for columns.Next() { 453 var ( 454 column = &migrator.ColumnType{ 455 PrimaryKeyValue: sql.NullBool{Valid: true}, 456 UniqueValue: sql.NullBool{Valid: true}, 457 } 458 datetimePrecision sql.NullInt64 459 radixValue sql.NullInt64 460 typeLenValue sql.NullInt64 461 identityIncrement sql.NullString 462 ) 463 464 err = columns.Scan( 465 &column.NameValue, &column.NullableValue, &column.DataTypeValue, &column.LengthValue, &column.DecimalSizeValue, 466 &radixValue, &column.ScaleValue, &datetimePrecision, &typeLenValue, &column.DefaultValueValue, &column.CommentValue, &identityIncrement, 467 ) 468 if err != nil { 469 return err 470 } 471 472 if typeLenValue.Valid && typeLenValue.Int64 > 0 { 473 column.LengthValue = typeLenValue 474 } 475 476 if (strings.HasPrefix(column.DefaultValueValue.String, "nextval('") && 477 strings.HasSuffix(column.DefaultValueValue.String, "seq'::regclass)")) || (identityIncrement.Valid && identityIncrement.String != "") { 478 column.AutoIncrementValue = sql.NullBool{Bool: true, Valid: true} 479 column.DefaultValueValue = sql.NullString{} 480 } 481 482 if column.DefaultValueValue.Valid { 483 column.DefaultValueValue.String = parseDefaultValueValue(column.DefaultValueValue.String) 484 } 485 486 if datetimePrecision.Valid { 487 column.DecimalSizeValue = datetimePrecision 488 } 489 490 columnTypes = append(columnTypes, column) 491 } 492 columns.Close() 493 494 // assign sql column type 495 { 496 rows, rowsErr := m.GetRows(currentSchema, table) 497 if rowsErr != nil { 498 return rowsErr 499 } 500 rawColumnTypes, err := rows.ColumnTypes() 501 if err != nil { 502 return err 503 } 504 for _, columnType := range columnTypes { 505 for _, c := range rawColumnTypes { 506 if c.Name() == columnType.Name() { 507 columnType.(*migrator.ColumnType).SQLColumnType = c 508 break 509 } 510 } 511 } 512 rows.Close() 513 } 514 515 // check primary, unique field 516 { 517 columnTypeRows, err := m.DB.Raw("SELECT constraint_name FROM information_schema.table_constraints tc JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_catalog, table_name, constraint_name) JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name WHERE constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_schema = ? AND c.table_name = ? AND constraint_type = ?", currentDatabase, currentSchema, table, "UNIQUE").Rows() 518 if err != nil { 519 return err 520 } 521 uniqueContraints := map[string]int{} 522 for columnTypeRows.Next() { 523 var constraintName string 524 columnTypeRows.Scan(&constraintName) 525 uniqueContraints[constraintName]++ 526 } 527 columnTypeRows.Close() 528 529 columnTypeRows, err = m.DB.Raw("SELECT c.column_name, constraint_name, constraint_type FROM information_schema.table_constraints tc JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_catalog, table_name, constraint_name) JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name WHERE constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_schema = ? AND c.table_name = ?", currentDatabase, currentSchema, table).Rows() 530 if err != nil { 531 return err 532 } 533 for columnTypeRows.Next() { 534 var name, constraintName, columnType string 535 columnTypeRows.Scan(&name, &constraintName, &columnType) 536 for _, c := range columnTypes { 537 mc := c.(*migrator.ColumnType) 538 if mc.NameValue.String == name { 539 switch columnType { 540 case "PRIMARY KEY": 541 mc.PrimaryKeyValue = sql.NullBool{Bool: true, Valid: true} 542 case "UNIQUE": 543 if uniqueContraints[constraintName] == 1 { 544 mc.UniqueValue = sql.NullBool{Bool: true, Valid: true} 545 } 546 } 547 break 548 } 549 } 550 } 551 columnTypeRows.Close() 552 } 553 554 // check column type 555 { 556 dataTypeRows, err := m.DB.Raw(`SELECT a.attname as column_name, format_type(a.atttypid, a.atttypmod) AS data_type 557 FROM pg_attribute a JOIN pg_class b ON a.attrelid = b.oid AND relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = ?) 558 WHERE a.attnum > 0 -- hide internal columns 559 AND NOT a.attisdropped -- hide deleted columns 560 AND b.relname = ?`, currentSchema, table).Rows() 561 if err != nil { 562 return err 563 } 564 565 for dataTypeRows.Next() { 566 var name, dataType string 567 dataTypeRows.Scan(&name, &dataType) 568 for _, c := range columnTypes { 569 mc := c.(*migrator.ColumnType) 570 if mc.NameValue.String == name { 571 mc.ColumnTypeValue = sql.NullString{String: dataType, Valid: true} 572 // Handle array type: _text -> text[] , _int4 -> integer[] 573 // Not support array size limits and array size limits because: 574 // https://www.postgresql.org/docs/current/arrays.html#ARRAYS-DECLARATION 575 if strings.HasPrefix(mc.DataTypeValue.String, "_") { 576 mc.DataTypeValue = sql.NullString{String: dataType, Valid: true} 577 } 578 break 579 } 580 } 581 } 582 dataTypeRows.Close() 583 } 584 585 return err 586 }) 587 return 588 } 589 590 func (m Migrator) GetRows(currentSchema interface{}, table interface{}) (*sql.Rows, error) { 591 name := table.(string) 592 if _, ok := currentSchema.(string); ok { 593 name = fmt.Sprintf("%v.%v", currentSchema, table) 594 } 595 596 return m.DB.Session(&gorm.Session{}).Table(name).Limit(1).Rows() 597 } 598 599 func (m Migrator) CurrentSchema(stmt *gorm.Statement, table string) (interface{}, interface{}) { 600 if strings.Contains(table, ".") { 601 if tables := strings.Split(table, `.`); len(tables) == 2 { 602 return tables[0], tables[1] 603 } 604 } 605 606 if stmt.TableExpr != nil { 607 if tables := strings.Split(stmt.TableExpr.SQL, `"."`); len(tables) == 2 { 608 return strings.TrimPrefix(tables[0], `"`), table 609 } 610 } 611 return clause.Expr{SQL: "CURRENT_SCHEMA()"}, table 612 } 613 614 func (m Migrator) CreateSequence(tx *gorm.DB, stmt *gorm.Statement, field *schema.Field, 615 serialDatabaseType string) (err error) { 616 617 _, table := m.CurrentSchema(stmt, stmt.Table) 618 tableName := table.(string) 619 620 sequenceName := strings.Join([]string{tableName, field.DBName, "seq"}, "_") 621 if err = tx.Exec(`CREATE SEQUENCE IF NOT EXISTS ? AS ?`, clause.Expr{SQL: sequenceName}, 622 clause.Expr{SQL: serialDatabaseType}).Error; err != nil { 623 return err 624 } 625 626 if err := tx.Exec("ALTER TABLE ? ALTER COLUMN ? SET DEFAULT nextval('?')", 627 clause.Expr{SQL: tableName}, clause.Expr{SQL: field.DBName}, clause.Expr{SQL: sequenceName}).Error; err != nil { 628 return err 629 } 630 631 if err := tx.Exec("ALTER SEQUENCE ? OWNED BY ?.?", 632 clause.Expr{SQL: sequenceName}, clause.Expr{SQL: tableName}, clause.Expr{SQL: field.DBName}).Error; err != nil { 633 return err 634 } 635 return 636 } 637 638 func (m Migrator) UpdateSequence(tx *gorm.DB, stmt *gorm.Statement, field *schema.Field, 639 serialDatabaseType string) (err error) { 640 641 sequenceName, err := m.getColumnSequenceName(tx, stmt, field) 642 if err != nil { 643 return err 644 } 645 646 if err = tx.Exec(`ALTER SEQUENCE IF EXISTS ? AS ?`, clause.Expr{SQL: sequenceName}, clause.Expr{SQL: serialDatabaseType}).Error; err != nil { 647 return err 648 } 649 650 if err := tx.Exec("ALTER TABLE ? ALTER COLUMN ? TYPE ?", 651 m.CurrentTable(stmt), clause.Expr{SQL: field.DBName}, clause.Expr{SQL: serialDatabaseType}).Error; err != nil { 652 return err 653 } 654 return 655 } 656 657 func (m Migrator) DeleteSequence(tx *gorm.DB, stmt *gorm.Statement, field *schema.Field, 658 fileType clause.Expr) (err error) { 659 660 sequenceName, err := m.getColumnSequenceName(tx, stmt, field) 661 if err != nil { 662 return err 663 } 664 665 if err := tx.Exec("ALTER TABLE ? ALTER COLUMN ? TYPE ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType).Error; err != nil { 666 return err 667 } 668 669 if err := tx.Exec("ALTER TABLE ? ALTER COLUMN ? DROP DEFAULT", 670 m.CurrentTable(stmt), clause.Expr{SQL: field.DBName}).Error; err != nil { 671 return err 672 } 673 674 if err = tx.Exec(`DROP SEQUENCE IF EXISTS ?`, clause.Expr{SQL: sequenceName}).Error; err != nil { 675 return err 676 } 677 678 return 679 } 680 681 func (m Migrator) getColumnSequenceName(tx *gorm.DB, stmt *gorm.Statement, field *schema.Field) ( 682 sequenceName string, err error) { 683 _, table := m.CurrentSchema(stmt, stmt.Table) 684 685 // DefaultValueValue is reset by ColumnTypes, search again. 686 var columnDefault string 687 err = tx.Raw( 688 `SELECT column_default FROM information_schema.columns WHERE table_name = ? AND column_name = ?`, 689 table, field.DBName).Scan(&columnDefault).Error 690 691 if err != nil { 692 return 693 } 694 695 sequenceName = strings.TrimSuffix( 696 strings.TrimPrefix(columnDefault, `nextval('`), 697 `'::regclass)`, 698 ) 699 return 700 } 701 702 func (m Migrator) GetIndexes(value interface{}) ([]gorm.Index, error) { 703 indexes := make([]gorm.Index, 0) 704 705 err := m.RunWithValue(value, func(stmt *gorm.Statement) error { 706 result := make([]*Index, 0) 707 scanErr := m.DB.Raw(indexSql, stmt.Table).Scan(&result).Error 708 if scanErr != nil { 709 return scanErr 710 } 711 indexMap := groupByIndexName(result) 712 for _, idx := range indexMap { 713 tempIdx := &migrator.Index{ 714 TableName: idx[0].TableName, 715 NameValue: idx[0].IndexName, 716 PrimaryKeyValue: sql.NullBool{ 717 Bool: idx[0].Primary, 718 Valid: true, 719 }, 720 UniqueValue: sql.NullBool{ 721 Bool: idx[0].NonUnique, 722 Valid: true, 723 }, 724 } 725 for _, x := range idx { 726 tempIdx.ColumnList = append(tempIdx.ColumnList, x.ColumnName) 727 } 728 indexes = append(indexes, tempIdx) 729 } 730 return nil 731 }) 732 return indexes, err 733 } 734 735 // Index table index info 736 type Index struct { 737 TableName string `gorm:"column:table_name"` 738 ColumnName string `gorm:"column:column_name"` 739 IndexName string `gorm:"column:index_name"` 740 NonUnique bool `gorm:"column:non_unique"` 741 Primary bool `gorm:"column:primary"` 742 } 743 744 func groupByIndexName(indexList []*Index) map[string][]*Index { 745 columnIndexMap := make(map[string][]*Index, len(indexList)) 746 for _, idx := range indexList { 747 columnIndexMap[idx.IndexName] = append(columnIndexMap[idx.IndexName], idx) 748 } 749 return columnIndexMap 750 } 751 752 func (m Migrator) GetTypeAliases(databaseTypeName string) []string { 753 return typeAliasMap[databaseTypeName] 754 } 755 756 // should reset prepared stmts when table changed 757 func (m Migrator) resetPreparedStmts() { 758 if m.DB.PrepareStmt { 759 if pdb, ok := m.DB.ConnPool.(*gorm.PreparedStmtDB); ok { 760 pdb.Reset() 761 } 762 } 763 } 764 765 func (m Migrator) DropColumn(dst interface{}, field string) error { 766 if err := m.Migrator.DropColumn(dst, field); err != nil { 767 return err 768 } 769 770 m.resetPreparedStmts() 771 return nil 772 } 773 774 func (m Migrator) RenameColumn(dst interface{}, oldName, field string) error { 775 if err := m.Migrator.RenameColumn(dst, oldName, field); err != nil { 776 return err 777 } 778 779 m.resetPreparedStmts() 780 return nil 781 } 782 783 func parseDefaultValueValue(defaultValue string) string { 784 return regexp.MustCompile(`^(.*?)(?:::.*)?$`).ReplaceAllString(defaultValue, "$1") 785 }