github.com/wanlay/gorm-dm8@v1.0.5/migrator.go (about)

     1  package dm
     2  
     3  import (
     4  	"fmt"
     5  	"strconv"
     6  	"strings"
     7  
     8  	"gorm.io/gorm"
     9  	"gorm.io/gorm/clause"
    10  	"gorm.io/gorm/migrator"
    11  	"gorm.io/gorm/schema"
    12  )
    13  
    14  type Migrator struct {
    15  	migrator.Migrator
    16  	Dialector
    17  }
    18  
    19  type BuildIndexOptionsInterface interface {
    20  	BuildIndexOptions([]schema.IndexOption, *gorm.Statement) []interface{}
    21  }
    22  
    23  func (m Migrator) CurrentDatabase() (name string) {
    24  	m.DB.Raw("SELECT SYS_CONTEXT ('userenv', 'current_schema') FROM DUAL").Row().Scan(&name)
    25  	return
    26  }
    27  
    28  func buildConstraint(constraint *schema.Constraint) (sql string, results []interface{}) {
    29  	sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
    30  	if constraint.OnDelete != "" {
    31  		sql += " ON DELETE " + constraint.OnDelete
    32  	}
    33  
    34  	if constraint.OnUpdate != "" {
    35  		sql += " ON UPDATE " + constraint.OnUpdate
    36  	}
    37  
    38  	var foreignKeys, references []interface{}
    39  	for _, field := range constraint.ForeignKeys {
    40  		foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName})
    41  	}
    42  
    43  	for _, field := range constraint.References {
    44  		references = append(references, clause.Column{Name: field.DBName})
    45  	}
    46  	results = append(results, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references)
    47  	return
    48  }
    49  
    50  func (m Migrator) CreateIndex(value interface{}, name string) error {
    51  	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
    52  		if idx := stmt.Schema.LookIndex(name); idx != nil {
    53  			opts := m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)
    54  			values := []interface{}{clause.Column{Name: m.Migrator.DB.NamingStrategy.IndexName(stmt.Table, idx.Name)}, m.CurrentTable(stmt), opts}
    55  
    56  			createIndexSQL := "CREATE "
    57  			if idx.Class != "" {
    58  				createIndexSQL += idx.Class + " "
    59  			}
    60  			createIndexSQL += "INDEX ? ON ??"
    61  
    62  			if idx.Type != "" {
    63  				createIndexSQL += " USING " + idx.Type
    64  			}
    65  
    66  			// if idx.Comment != "" {
    67  			//	createIndexSQL += fmt.Sprintf(" COMMENT '%s'", idx.Comment)
    68  			// }
    69  
    70  			if idx.Option != "" {
    71  				createIndexSQL += " " + idx.Option
    72  			}
    73  
    74  			return m.DB.Exec(createIndexSQL, values...).Error
    75  		}
    76  
    77  		return fmt.Errorf("failed to create index with name %s", name)
    78  	})
    79  }
    80  
    81  func (m Migrator) CreateTable(values ...interface{}) error {
    82  	for _, value := range values {
    83  		m.TryQuotifyReservedWords(value)
    84  		m.TryRemoveOnUpdate(value)
    85  	}
    86  
    87  	for _, value := range m.ReorderModels(values, false) {
    88  		tx := m.DB.Session(&gorm.Session{})
    89  		if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) {
    90  			var (
    91  				createTableSQL          = "CREATE TABLE ? ("
    92  				values                  = []interface{}{m.CurrentTable(stmt)}
    93  				hasPrimaryKeyInDataType bool
    94  			)
    95  
    96  			for _, dbName := range stmt.Schema.DBNames {
    97  				field := stmt.Schema.FieldsByDBName[dbName]
    98  				if !field.IgnoreMigration {
    99  					createTableSQL += "? ?"
   100  					hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(string(field.DataType)), "PRIMARY KEY")
   101  					f := m.DB.Migrator().FullDataTypeOf(field)
   102  					if field.AutoIncrement {
   103  						f.SQL = "INTEGER IDENTITY(1, " + strconv.FormatInt(field.AutoIncrementIncrement, 10) + ")"
   104  					}
   105  					values = append(values, clause.Column{Name: dbName}, f)
   106  					createTableSQL += ","
   107  				}
   108  			}
   109  
   110  			if !hasPrimaryKeyInDataType && len(stmt.Schema.PrimaryFields) > 0 {
   111  				createTableSQL += "PRIMARY KEY ?,"
   112  				primaryKeys := []interface{}{}
   113  				for _, field := range stmt.Schema.PrimaryFields {
   114  					primaryKeys = append(primaryKeys, clause.Column{Name: field.DBName})
   115  				}
   116  
   117  				values = append(values, primaryKeys)
   118  			}
   119  
   120  			for _, idx := range stmt.Schema.ParseIndexes() {
   121  				if m.CreateIndexAfterCreateTable {
   122  					defer func(value interface{}, name string) {
   123  						if errr == nil {
   124  							// errr = tx.Migrator().CreateIndex(value, name)
   125  							errr = m.CreateIndex(value, name)
   126  						}
   127  					}(value, idx.Name)
   128  				} else {
   129  					if idx.Class != "" {
   130  						createTableSQL += idx.Class + " "
   131  					}
   132  					createTableSQL += "INDEX ? ?"
   133  
   134  					if idx.Comment != "" {
   135  						createTableSQL += fmt.Sprintf(" COMMENT '%s'", idx.Comment)
   136  					}
   137  
   138  					if idx.Option != "" {
   139  						createTableSQL += " " + idx.Option
   140  					}
   141  
   142  					createTableSQL += ","
   143  					values = append(values, clause.Expr{SQL: idx.Name}, tx.Migrator().(migrator.BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt))
   144  				}
   145  			}
   146  
   147  			for _, rel := range stmt.Schema.Relationships.Relations {
   148  				if !m.DB.DisableForeignKeyConstraintWhenMigrating {
   149  					if constraint := rel.ParseConstraint(); constraint != nil {
   150  						if constraint.Schema == stmt.Schema {
   151  							sql, vars := buildConstraint(constraint)
   152  							createTableSQL += sql + ","
   153  							values = append(values, vars...)
   154  						}
   155  					}
   156  				}
   157  			}
   158  			for _, chk := range stmt.Schema.ParseCheckConstraints() {
   159  				createTableSQL += "CONSTRAINT ? CHECK (?),"
   160  				values = append(values, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint})
   161  			}
   162  
   163  			createTableSQL = strings.TrimSuffix(createTableSQL, ",")
   164  
   165  			createTableSQL += ")"
   166  
   167  			if tableOption, ok := m.DB.Get("gorm:table_options"); ok {
   168  				createTableSQL += fmt.Sprint(tableOption)
   169  			}
   170  
   171  			errr = tx.Exec(createTableSQL, values...).Error
   172  			return errr
   173  		}); err != nil {
   174  			return err
   175  		}
   176  	}
   177  	return nil
   178  }
   179  
   180  func (m Migrator) DropTable(values ...interface{}) error {
   181  	values = m.ReorderModels(values, false)
   182  	for i := len(values) - 1; i >= 0; i-- {
   183  		value := values[i]
   184  		tx := m.DB.Session(&gorm.Session{})
   185  		if m.HasTable(value) {
   186  			if err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
   187  				return tx.Exec("DROP TABLE ? CASCADE CONSTRAINTS", clause.Table{Name: stmt.Table}).Error
   188  			}); err != nil {
   189  				return err
   190  			}
   191  		}
   192  	}
   193  	return nil
   194  }
   195  
   196  func (m Migrator) HasTable(value interface{}) bool {
   197  	var count int64
   198  
   199  	m.RunWithValue(value, func(stmt *gorm.Statement) error {
   200  		return m.DB.Raw("SELECT COUNT(*) FROM USER_TABLES WHERE TABLE_NAME = ?", stmt.Table).Row().Scan(&count)
   201  	})
   202  
   203  	return count > 0
   204  }
   205  
   206  func (m Migrator) RenameTable(oldName, newName interface{}) (err error) {
   207  	resolveTable := func(name interface{}) (result string, err error) {
   208  		if v, ok := name.(string); ok {
   209  			result = v
   210  		} else {
   211  			stmt := &gorm.Statement{DB: m.DB}
   212  			if err = stmt.Parse(name); err == nil {
   213  				result = stmt.Table
   214  			}
   215  		}
   216  		return
   217  	}
   218  
   219  	var oldTable, newTable string
   220  
   221  	if oldTable, err = resolveTable(oldName); err != nil {
   222  		return
   223  	}
   224  
   225  	if newTable, err = resolveTable(newName); err != nil {
   226  		return
   227  	}
   228  
   229  	if !m.HasTable(oldTable) {
   230  		return
   231  	}
   232  
   233  	return m.DB.Exec("RENAME TABLE ? TO ?",
   234  		clause.Table{Name: oldTable},
   235  		clause.Table{Name: newTable},
   236  	).Error
   237  }
   238  
   239  func (m Migrator) AddColumn(value interface{}, field string) error {
   240  	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
   241  		if field := stmt.Schema.LookUpField(field); field != nil {
   242  			return m.DB.Exec(
   243  				"ALTER TABLE ? ADD ? ?",
   244  				clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, m.DB.Migrator().FullDataTypeOf(field),
   245  			).Error
   246  		}
   247  		return fmt.Errorf("failed to look up field with name: %s", field)
   248  	})
   249  }
   250  
   251  func (m Migrator) DropColumn(value interface{}, name string) error {
   252  	if !m.HasColumn(value, name) {
   253  		return nil
   254  	}
   255  
   256  	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
   257  		if field := stmt.Schema.LookUpField(name); field != nil {
   258  			name = field.DBName
   259  		}
   260  
   261  		return m.DB.Exec(
   262  			"ALTER TABLE ? DROP ?",
   263  			clause.Table{Name: stmt.Table},
   264  			clause.Column{Name: name},
   265  		).Error
   266  	})
   267  }
   268  
   269  func (m Migrator) AlterColumn(value interface{}, field string) error {
   270  	if !m.HasColumn(value, field) {
   271  		return nil
   272  	}
   273  
   274  	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
   275  		if field := stmt.Schema.LookUpField(field); field != nil {
   276  			return m.DB.Exec(
   277  				"ALTER TABLE ? MODIFY ? ?",
   278  				clause.Table{Name: stmt.Table},
   279  				clause.Column{Name: field.DBName},
   280  				m.FullDataTypeOf(field),
   281  			).Error
   282  		}
   283  		return fmt.Errorf("failed to look up field with name: %s", field)
   284  	})
   285  }
   286  
   287  func (m Migrator) HasColumn(value interface{}, field string) bool {
   288  	var count int64
   289  	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
   290  		return m.DB.Raw("SELECT COUNT(*) FROM USER_TAB_COLUMNS WHERE TABLE_NAME = ? AND COLUMN_NAME = ?", stmt.Table, field).Row().Scan(&count)
   291  	}) == nil && count > 0
   292  }
   293  
   294  func (m Migrator) CreateConstraint(value interface{}, name string) error {
   295  	m.TryRemoveOnUpdate(value)
   296  	return m.Migrator.CreateConstraint(value, name)
   297  }
   298  
   299  func (m Migrator) DropConstraint(value interface{}, name string) error {
   300  	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
   301  		for _, chk := range stmt.Schema.ParseCheckConstraints() {
   302  			if chk.Name == name {
   303  				return m.DB.Exec(
   304  					"ALTER TABLE ? DROP CHECK ?",
   305  					clause.Table{Name: stmt.Table}, clause.Column{Name: name},
   306  				).Error
   307  			}
   308  		}
   309  
   310  		return m.DB.Exec(
   311  			"ALTER TABLE ? DROP CONSTRAINT ?",
   312  			clause.Table{Name: stmt.Table}, clause.Column{Name: name},
   313  		).Error
   314  	})
   315  }
   316  
   317  func (m Migrator) HasConstraint(value interface{}, name string) bool {
   318  	var count int64
   319  	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
   320  		return m.DB.Raw(
   321  			"SELECT COUNT(*) FROM USER_CONSTRAINTS WHERE TABLE_NAME = ? AND CONSTRAINT_NAME = ?", stmt.Table, name,
   322  		).Row().Scan(&count)
   323  	}) == nil && count > 0
   324  }
   325  
   326  func (m Migrator) DropIndex(value interface{}, name string) error {
   327  	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
   328  		if idx := stmt.Schema.LookIndex(name); idx != nil {
   329  			name = idx.Name
   330  		}
   331  
   332  		return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}, clause.Table{Name: stmt.Table}).Error
   333  	})
   334  }
   335  
   336  func (m Migrator) HasIndex(value interface{}, name string) bool {
   337  	var count int64
   338  	m.RunWithValue(value, func(stmt *gorm.Statement) error {
   339  		if idx := stmt.Schema.LookIndex(name); idx != nil {
   340  			name = idx.Name
   341  		}
   342  
   343  		return m.DB.Raw(
   344  			fmt.Sprintf(`SELECT COUNT(*) FROM USER_INDEXES WHERE TABLE_NAME = ('%s') AND INDEX_NAME = ('%s')`,
   345  				m.Migrator.DB.NamingStrategy.TableName(stmt.Table),
   346  				m.Migrator.DB.NamingStrategy.IndexName(stmt.Table, name),
   347  			),
   348  		).Row().Scan(&count)
   349  	})
   350  
   351  	return count > 0
   352  }
   353  
   354  // https://docs.dm.com/database/121/SPATL/alter-index-rename.htm
   355  func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
   356  	panic("TODO")
   357  	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
   358  		return m.DB.Exec(
   359  			"ALTER INDEX ?.? RENAME TO ?", // wat
   360  			clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: newName},
   361  		).Error
   362  	})
   363  }
   364  
   365  func (m Migrator) TryRemoveOnUpdate(values ...interface{}) error {
   366  	for _, value := range values {
   367  		if err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
   368  			for _, rel := range stmt.Schema.Relationships.Relations {
   369  				constraint := rel.ParseConstraint()
   370  				if constraint != nil {
   371  					rel.Field.TagSettings["CONSTRAINT"] = strings.ReplaceAll(rel.Field.TagSettings["CONSTRAINT"], fmt.Sprintf("ON UPDATE %s", constraint.OnUpdate), "")
   372  				}
   373  			}
   374  			return nil
   375  		}); err != nil {
   376  			return err
   377  		}
   378  	}
   379  	return nil
   380  }
   381  
   382  func (m Migrator) TryQuotifyReservedWords(values ...interface{}) error {
   383  	for _, value := range values {
   384  		if err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
   385  			for idx, v := range stmt.Schema.DBNames {
   386  				if IsReservedWord(v) {
   387  					stmt.Schema.DBNames[idx] = fmt.Sprintf(`"%s"`, v)
   388  				}
   389  			}
   390  
   391  			for _, v := range stmt.Schema.Fields {
   392  				if IsReservedWord(v.DBName) {
   393  					v.DBName = fmt.Sprintf(`"%s"`, v.DBName)
   394  				}
   395  			}
   396  			return nil
   397  		}); err != nil {
   398  			return err
   399  		}
   400  	}
   401  	return nil
   402  }