github.com/wfusion/gofusion@v1.1.14/common/infra/drivers/orm/opengauss/migrator.go (about)

     1  package opengauss
     2  
     3  import (
     4  	"database/sql"
     5  	"fmt"
     6  	"regexp"
     7  	"strings"
     8  
     9  	"gorm.io/gorm"
    10  	"gorm.io/gorm/clause"
    11  	"gorm.io/gorm/migrator"
    12  	"gorm.io/gorm/schema"
    13  )
    14  
// indexSql lists the indexes of one table: it returns one row per (index, indexed column)
// pair for ordinary tables (relkind 'r') whose pg_class.relname equals the single bind
// parameter.
//
// NOTE(review): the non_unique alias is misleading — it is selected from
// pg_index.indisunique, so it is TRUE for unique indexes. The query is also not filtered by
// schema, so a same-named table in another schema contributes rows as well.
const indexSql = `
select
    t.relname as table_name,
    i.relname as index_name,
    a.attname as column_name,
    ix.indisunique as non_unique,
	ix.indisprimary as primary
from
    pg_class t,
    pg_class i,
    pg_index ix,
    pg_attribute a
where
    t.oid = ix.indrelid
    and i.oid = ix.indexrelid
    and a.attrelid = t.oid
    and a.attnum = ANY(ix.indkey)
    and t.relkind = 'r'
    and t.relname = ?
`
    35  
    36  var typeAliasMap = map[string][]string{
    37  	"int2":                     {"smallint"},
    38  	"int4":                     {"integer"},
    39  	"int8":                     {"bigint"},
    40  	"smallint":                 {"int2"},
    41  	"integer":                  {"int4"},
    42  	"bigint":                   {"int8"},
    43  	"decimal":                  {"numeric"},
    44  	"numeric":                  {"decimal"},
    45  	"timestamptz":              {"timestamp with time zone"},
    46  	"timestamp with time zone": {"timestamptz"},
    47  }
    48  
// Migrator handles schema migrations for openGauss. It embeds gorm's generic
// migrator.Migrator, whose methods it reuses, and overrides selected methods with queries
// against information_schema and the pg_catalog system tables.
type Migrator struct {
	migrator.Migrator
}
    52  
    53  func (m Migrator) CurrentDatabase() (name string) {
    54  	m.DB.Raw("SELECT CURRENT_DATABASE()").Scan(&name)
    55  	return
    56  }
    57  
    58  func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
    59  	for _, opt := range opts {
    60  		str := stmt.Quote(opt.DBName)
    61  		if opt.Expression != "" {
    62  			str = opt.Expression
    63  		}
    64  
    65  		if opt.Collate != "" {
    66  			str += " COLLATE " + opt.Collate
    67  		}
    68  
    69  		if opt.Sort != "" {
    70  			str += " " + opt.Sort
    71  		}
    72  		results = append(results, clause.Expr{SQL: str})
    73  	}
    74  	return
    75  }
    76  
    77  func (m Migrator) HasIndex(value interface{}, name string) bool {
    78  	var count int64
    79  	m.RunWithValue(value, func(stmt *gorm.Statement) error {
    80  		if stmt.Schema != nil {
    81  			if idx := stmt.Schema.LookIndex(name); idx != nil {
    82  				name = idx.Name
    83  			}
    84  		}
    85  		currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table)
    86  		return m.DB.Raw(
    87  			"SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ? AND schemaname = ?", curTable, name, currentSchema,
    88  		).Scan(&count).Error
    89  	})
    90  
    91  	return count > 0
    92  }
    93  
    94  func (m Migrator) CreateIndex(value interface{}, name string) error {
    95  	if !m.HasIndex(value, name) {
    96  		return nil
    97  	}
    98  
    99  	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
   100  		if stmt.Schema != nil {
   101  			if idx := stmt.Schema.LookIndex(name); idx != nil {
   102  				opts := m.BuildIndexOptions(idx.Fields, stmt)
   103  				values := []interface{}{clause.Column{Name: idx.Name}, m.CurrentTable(stmt), opts}
   104  				createIndexSQL := "CREATE "
   105  				if idx.Class != "" {
   106  					createIndexSQL += idx.Class + " "
   107  				}
   108  				createIndexSQL += "INDEX "
   109  
   110  				if strings.TrimSpace(strings.ToUpper(idx.Option)) == "CONCURRENTLY" {
   111  					createIndexSQL += "CONCURRENTLY "
   112  				}
   113  
   114  				createIndexSQL += "? ON ?"
   115  
   116  				if idx.Type != "" {
   117  					createIndexSQL += " USING " + idx.Type + "(?)"
   118  				} else {
   119  					createIndexSQL += " ?"
   120  				}
   121  
   122  				if idx.Where != "" {
   123  					createIndexSQL += " WHERE " + idx.Where
   124  				}
   125  
   126  				return m.DB.Exec(createIndexSQL, values...).Error
   127  			}
   128  		}
   129  
   130  		return fmt.Errorf("failed to create index with name %v", name)
   131  	})
   132  }
   133  
   134  func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
   135  	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
   136  		return m.DB.Exec(
   137  			"ALTER INDEX ? RENAME TO ?",
   138  			clause.Column{Name: oldName}, clause.Column{Name: newName},
   139  		).Error
   140  	})
   141  }
   142  
   143  func (m Migrator) DropIndex(value interface{}, name string) error {
   144  	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
   145  		if stmt.Schema != nil {
   146  			if idx := stmt.Schema.LookIndex(name); idx != nil {
   147  				name = idx.Name
   148  			}
   149  		}
   150  
   151  		return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}).Error
   152  	})
   153  }
   154  
   155  func (m Migrator) GetTables() (tableList []string, err error) {
   156  	currentSchema, _ := m.CurrentSchema(m.DB.Statement, "")
   157  	return tableList, m.DB.Raw("SELECT table_name FROM information_schema.tables WHERE table_schema = ? AND table_type = ?", currentSchema, "BASE TABLE").Scan(&tableList).Error
   158  }
   159  
   160  func (m Migrator) CreateTable(values ...interface{}) (err error) {
   161  	if err = m.Migrator.CreateTable(values...); err != nil {
   162  		return
   163  	}
   164  	for _, value := range m.ReorderModels(values, false) {
   165  		if err = m.RunWithValue(value, func(stmt *gorm.Statement) error {
   166  			if stmt.Schema != nil {
   167  				for _, fieldName := range stmt.Schema.DBNames {
   168  					field := stmt.Schema.FieldsByDBName[fieldName]
   169  					if field.Comment != "" {
   170  						if err := m.DB.Exec(
   171  							"COMMENT ON COLUMN ?.? IS ?",
   172  							m.CurrentTable(stmt), clause.Column{Name: field.DBName}, gorm.Expr(m.Migrator.Dialector.Explain("$1", field.Comment)),
   173  						).Error; err != nil {
   174  							return err
   175  						}
   176  					}
   177  				}
   178  			}
   179  			return nil
   180  		}); err != nil {
   181  			return
   182  		}
   183  	}
   184  	return
   185  }
   186  
   187  func (m Migrator) HasTable(value interface{}) bool {
   188  	var count int64
   189  	m.RunWithValue(value, func(stmt *gorm.Statement) error {
   190  		currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table)
   191  		return m.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?", currentSchema, curTable, "BASE TABLE").Scan(&count).Error
   192  	})
   193  	return count > 0
   194  }
   195  
   196  func (m Migrator) DropTable(values ...interface{}) error {
   197  	values = m.ReorderModels(values, false)
   198  	tx := m.DB.Session(&gorm.Session{})
   199  	for i := len(values) - 1; i >= 0; i-- {
   200  		if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error {
   201  			return tx.Exec("DROP TABLE IF EXISTS ? CASCADE", m.CurrentTable(stmt)).Error
   202  		}); err != nil {
   203  			return err
   204  		}
   205  	}
   206  	return nil
   207  }
   208  
   209  func (m Migrator) AddColumn(value interface{}, field string) error {
   210  	if err := m.Migrator.AddColumn(value, field); err != nil {
   211  		return err
   212  	}
   213  	m.resetPreparedStmts()
   214  
   215  	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
   216  		if stmt.Schema != nil {
   217  			if field := stmt.Schema.LookUpField(field); field != nil {
   218  				if field.Comment != "" {
   219  					if err := m.DB.Exec(
   220  						"COMMENT ON COLUMN ?.? IS ?",
   221  						m.CurrentTable(stmt), clause.Column{Name: field.DBName}, gorm.Expr(m.Migrator.Dialector.Explain("$1", field.Comment)),
   222  					).Error; err != nil {
   223  						return err
   224  					}
   225  				}
   226  			}
   227  		}
   228  		return nil
   229  	})
   230  }
   231  
   232  func (m Migrator) HasColumn(value interface{}, field string) bool {
   233  	var count int64
   234  	m.RunWithValue(value, func(stmt *gorm.Statement) error {
   235  		name := field
   236  		if stmt.Schema != nil {
   237  			if field := stmt.Schema.LookUpField(field); field != nil {
   238  				name = field.DBName
   239  			}
   240  		}
   241  
   242  		currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table)
   243  		return m.DB.Raw(
   244  			"SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?",
   245  			currentSchema, curTable, name,
   246  		).Scan(&count).Error
   247  	})
   248  
   249  	return count > 0
   250  }
   251  
   252  func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error {
   253  	// skip primary field
   254  	if !field.PrimaryKey {
   255  		if err := m.Migrator.MigrateColumn(value, field, columnType); err != nil {
   256  			return err
   257  		}
   258  	}
   259  
   260  	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
   261  		var description string
   262  		currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table)
   263  		values := []interface{}{currentSchema, curTable, field.DBName, stmt.Table, currentSchema}
   264  		checkSQL := "SELECT description FROM pg_catalog.pg_description "
   265  		checkSQL += "WHERE objsubid = (SELECT ordinal_position FROM information_schema.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?) "
   266  		checkSQL += "AND objoid = (SELECT oid FROM pg_catalog.pg_class WHERE relname = ? AND relnamespace = "
   267  		checkSQL += "(SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = ?))"
   268  		m.DB.Raw(checkSQL, values...).Scan(&description)
   269  
   270  		comment := strings.Trim(field.Comment, "'")
   271  		comment = strings.Trim(comment, `"`)
   272  		if field.Comment != "" && comment != description {
   273  			if err := m.DB.Exec(
   274  				"COMMENT ON COLUMN ?.? IS ?",
   275  				m.CurrentTable(stmt), clause.Column{Name: field.DBName}, gorm.Expr(m.Migrator.Dialector.Explain("$1", field.Comment)),
   276  			).Error; err != nil {
   277  				return err
   278  			}
   279  		}
   280  		return nil
   281  	})
   282  }
   283  
   284  // AlterColumn alter value's `field` column' type based on schema definition
   285  func (m Migrator) AlterColumn(value interface{}, field string) error {
   286  	err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
   287  		if stmt.Schema != nil {
   288  			if field := stmt.Schema.LookUpField(field); field != nil {
   289  				var (
   290  					columnTypes, _  = m.DB.Migrator().ColumnTypes(value)
   291  					fieldColumnType *migrator.ColumnType
   292  				)
   293  				for _, columnType := range columnTypes {
   294  					if columnType.Name() == field.DBName {
   295  						fieldColumnType, _ = columnType.(*migrator.ColumnType)
   296  					}
   297  				}
   298  
   299  				fileType := clause.Expr{SQL: m.DataTypeOf(field)}
   300  				// check for typeName and SQL name
   301  				isSameType := true
   302  				if fieldColumnType.DatabaseTypeName() != fileType.SQL {
   303  					isSameType = false
   304  					// if different, also check for aliases
   305  					aliases := m.GetTypeAliases(fieldColumnType.DatabaseTypeName())
   306  					for _, alias := range aliases {
   307  						if strings.HasPrefix(fileType.SQL, alias) {
   308  							isSameType = true
   309  							break
   310  						}
   311  					}
   312  				}
   313  
   314  				// not same, migrate
   315  				if !isSameType {
   316  					filedColumnAutoIncrement, _ := fieldColumnType.AutoIncrement()
   317  					if field.AutoIncrement && filedColumnAutoIncrement { // update
   318  						serialDatabaseType, _ := getSerialDatabaseType(fileType.SQL)
   319  						if t, _ := fieldColumnType.ColumnType(); t != serialDatabaseType {
   320  							if err := m.UpdateSequence(m.DB, stmt, field, serialDatabaseType); err != nil {
   321  								return err
   322  							}
   323  						}
   324  					} else if field.AutoIncrement && !filedColumnAutoIncrement { // create
   325  						serialDatabaseType, _ := getSerialDatabaseType(fileType.SQL)
   326  						if err := m.CreateSequence(m.DB, stmt, field, serialDatabaseType); err != nil {
   327  							return err
   328  						}
   329  					} else if !field.AutoIncrement && filedColumnAutoIncrement { // delete
   330  						if err := m.DeleteSequence(m.DB, stmt, field, fileType); err != nil {
   331  							return err
   332  						}
   333  					} else {
   334  						if err := m.modifyColumn(stmt, field, fileType, fieldColumnType); err != nil {
   335  							return err
   336  						}
   337  					}
   338  				}
   339  
   340  				if null, _ := fieldColumnType.Nullable(); null == field.NotNull {
   341  					if field.NotNull {
   342  						if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? SET NOT NULL", m.CurrentTable(stmt), clause.Column{Name: field.DBName}).Error; err != nil {
   343  							return err
   344  						}
   345  					} else {
   346  						if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? DROP NOT NULL", m.CurrentTable(stmt), clause.Column{Name: field.DBName}).Error; err != nil {
   347  							return err
   348  						}
   349  					}
   350  				}
   351  
   352  				if uniq, _ := fieldColumnType.Unique(); !uniq && field.Unique {
   353  					idxName := clause.Column{Name: m.DB.Config.NamingStrategy.IndexName(stmt.Table, field.DBName)}
   354  					// Not a unique constraint but a unique index
   355  					if !m.HasIndex(stmt.Table, idxName.Name) {
   356  						if err := m.DB.Exec("ALTER TABLE ? ADD CONSTRAINT ? UNIQUE(?)", m.CurrentTable(stmt), idxName, clause.Column{Name: field.DBName}).Error; err != nil {
   357  							return err
   358  						}
   359  					}
   360  				}
   361  
   362  				if v, ok := fieldColumnType.DefaultValue(); (field.DefaultValueInterface == nil && ok) || v != field.DefaultValue {
   363  					if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") {
   364  						if field.DefaultValueInterface != nil {
   365  							defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}}
   366  							m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValueInterface)
   367  							if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? SET DEFAULT ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, clause.Expr{SQL: m.Dialector.Explain(defaultStmt.SQL.String(), field.DefaultValueInterface)}).Error; err != nil {
   368  								return err
   369  							}
   370  						} else if field.DefaultValue != "(-)" {
   371  							if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? SET DEFAULT ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DefaultValue}).Error; err != nil {
   372  								return err
   373  							}
   374  						} else {
   375  							if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? DROP DEFAULT", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DefaultValue}).Error; err != nil {
   376  								return err
   377  							}
   378  						}
   379  					}
   380  				}
   381  				return nil
   382  			}
   383  		}
   384  		return fmt.Errorf("failed to look up field with name: %s", field)
   385  	})
   386  
   387  	if err != nil {
   388  		return err
   389  	}
   390  	m.resetPreparedStmts()
   391  	return nil
   392  }
   393  
   394  func (m Migrator) modifyColumn(stmt *gorm.Statement, field *schema.Field, targetType clause.Expr, existingColumn *migrator.ColumnType) error {
   395  	alterSQL := "ALTER TABLE ? ALTER COLUMN ? TYPE ? USING ?::?"
   396  	isUncastableDefaultValue := false
   397  
   398  	if targetType.SQL == "boolean" {
   399  		switch existingColumn.DatabaseTypeName() {
   400  		case "int2", "int8", "numeric":
   401  			alterSQL = "ALTER TABLE ? ALTER COLUMN ? TYPE ? USING ?::int::?"
   402  		}
   403  		isUncastableDefaultValue = true
   404  	}
   405  
   406  	if dv, _ := existingColumn.DefaultValue(); dv != "" && isUncastableDefaultValue {
   407  		if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? DROP DEFAULT", m.CurrentTable(stmt), clause.Column{Name: field.DBName}).Error; err != nil {
   408  			return err
   409  		}
   410  	}
   411  	if err := m.DB.Exec(alterSQL, m.CurrentTable(stmt), clause.Column{Name: field.DBName}, targetType, clause.Column{Name: field.DBName}, targetType).Error; err != nil {
   412  		return err
   413  	}
   414  	return nil
   415  }
   416  
   417  func (m Migrator) HasConstraint(value interface{}, name string) bool {
   418  	var count int64
   419  	m.RunWithValue(value, func(stmt *gorm.Statement) error {
   420  		constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
   421  		currentSchema, curTable := m.CurrentSchema(stmt, table)
   422  		if constraint != nil {
   423  			name = constraint.Name
   424  		} else if chk != nil {
   425  			name = chk.Name
   426  		}
   427  
   428  		return m.DB.Raw(
   429  			"SELECT count(*) FROM INFORMATION_SCHEMA.table_constraints WHERE table_schema = ? AND table_name = ? AND constraint_name = ?",
   430  			currentSchema, curTable, name,
   431  		).Scan(&count).Error
   432  	})
   433  
   434  	return count > 0
   435  }
   436  
   437  func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType, err error) {
   438  	columnTypes = make([]gorm.ColumnType, 0)
   439  	err = m.RunWithValue(value, func(stmt *gorm.Statement) error {
   440  		var (
   441  			currentDatabase      = m.DB.Migrator().CurrentDatabase()
   442  			currentSchema, table = m.CurrentSchema(stmt, stmt.Table)
   443  			columns, err         = m.DB.Raw(
   444  				"SELECT c.column_name, c.is_nullable = 'YES', c.udt_name, c.character_maximum_length, c.numeric_precision, c.numeric_precision_radix, c.numeric_scale, c.datetime_precision, 8 * typlen, c.column_default, pd.description, c.identity_increment FROM information_schema.columns AS c JOIN pg_type AS pgt ON c.udt_name = pgt.typname LEFT JOIN pg_catalog.pg_description as pd ON pd.objsubid = c.ordinal_position AND pd.objoid = (SELECT oid FROM pg_catalog.pg_class WHERE relname = c.table_name AND relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = c.table_schema)) where table_catalog = ? AND table_schema = ? AND table_name = ?",
   445  				currentDatabase, currentSchema, table).Rows()
   446  		)
   447  
   448  		if err != nil {
   449  			return err
   450  		}
   451  
   452  		for columns.Next() {
   453  			var (
   454  				column = &migrator.ColumnType{
   455  					PrimaryKeyValue: sql.NullBool{Valid: true},
   456  					UniqueValue:     sql.NullBool{Valid: true},
   457  				}
   458  				datetimePrecision sql.NullInt64
   459  				radixValue        sql.NullInt64
   460  				typeLenValue      sql.NullInt64
   461  				identityIncrement sql.NullString
   462  			)
   463  
   464  			err = columns.Scan(
   465  				&column.NameValue, &column.NullableValue, &column.DataTypeValue, &column.LengthValue, &column.DecimalSizeValue,
   466  				&radixValue, &column.ScaleValue, &datetimePrecision, &typeLenValue, &column.DefaultValueValue, &column.CommentValue, &identityIncrement,
   467  			)
   468  			if err != nil {
   469  				return err
   470  			}
   471  
   472  			if typeLenValue.Valid && typeLenValue.Int64 > 0 {
   473  				column.LengthValue = typeLenValue
   474  			}
   475  
   476  			if (strings.HasPrefix(column.DefaultValueValue.String, "nextval('") &&
   477  				strings.HasSuffix(column.DefaultValueValue.String, "seq'::regclass)")) || (identityIncrement.Valid && identityIncrement.String != "") {
   478  				column.AutoIncrementValue = sql.NullBool{Bool: true, Valid: true}
   479  				column.DefaultValueValue = sql.NullString{}
   480  			}
   481  
   482  			if column.DefaultValueValue.Valid {
   483  				column.DefaultValueValue.String = parseDefaultValueValue(column.DefaultValueValue.String)
   484  			}
   485  
   486  			if datetimePrecision.Valid {
   487  				column.DecimalSizeValue = datetimePrecision
   488  			}
   489  
   490  			columnTypes = append(columnTypes, column)
   491  		}
   492  		columns.Close()
   493  
   494  		// assign sql column type
   495  		{
   496  			rows, rowsErr := m.GetRows(currentSchema, table)
   497  			if rowsErr != nil {
   498  				return rowsErr
   499  			}
   500  			rawColumnTypes, err := rows.ColumnTypes()
   501  			if err != nil {
   502  				return err
   503  			}
   504  			for _, columnType := range columnTypes {
   505  				for _, c := range rawColumnTypes {
   506  					if c.Name() == columnType.Name() {
   507  						columnType.(*migrator.ColumnType).SQLColumnType = c
   508  						break
   509  					}
   510  				}
   511  			}
   512  			rows.Close()
   513  		}
   514  
   515  		// check primary, unique field
   516  		{
   517  			columnTypeRows, err := m.DB.Raw("SELECT constraint_name FROM information_schema.table_constraints tc JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_catalog, table_name, constraint_name) JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name WHERE constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_schema = ? AND c.table_name = ? AND constraint_type = ?", currentDatabase, currentSchema, table, "UNIQUE").Rows()
   518  			if err != nil {
   519  				return err
   520  			}
   521  			uniqueContraints := map[string]int{}
   522  			for columnTypeRows.Next() {
   523  				var constraintName string
   524  				columnTypeRows.Scan(&constraintName)
   525  				uniqueContraints[constraintName]++
   526  			}
   527  			columnTypeRows.Close()
   528  
   529  			columnTypeRows, err = m.DB.Raw("SELECT c.column_name, constraint_name, constraint_type FROM information_schema.table_constraints tc JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_catalog, table_name, constraint_name) JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name WHERE constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_schema = ? AND c.table_name = ?", currentDatabase, currentSchema, table).Rows()
   530  			if err != nil {
   531  				return err
   532  			}
   533  			for columnTypeRows.Next() {
   534  				var name, constraintName, columnType string
   535  				columnTypeRows.Scan(&name, &constraintName, &columnType)
   536  				for _, c := range columnTypes {
   537  					mc := c.(*migrator.ColumnType)
   538  					if mc.NameValue.String == name {
   539  						switch columnType {
   540  						case "PRIMARY KEY":
   541  							mc.PrimaryKeyValue = sql.NullBool{Bool: true, Valid: true}
   542  						case "UNIQUE":
   543  							if uniqueContraints[constraintName] == 1 {
   544  								mc.UniqueValue = sql.NullBool{Bool: true, Valid: true}
   545  							}
   546  						}
   547  						break
   548  					}
   549  				}
   550  			}
   551  			columnTypeRows.Close()
   552  		}
   553  
   554  		// check column type
   555  		{
   556  			dataTypeRows, err := m.DB.Raw(`SELECT a.attname as column_name, format_type(a.atttypid, a.atttypmod) AS data_type
   557  		FROM pg_attribute a JOIN pg_class b ON a.attrelid = b.oid AND relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = ?)
   558  		WHERE a.attnum > 0 -- hide internal columns
   559  		AND NOT a.attisdropped -- hide deleted columns
   560  		AND b.relname = ?`, currentSchema, table).Rows()
   561  			if err != nil {
   562  				return err
   563  			}
   564  
   565  			for dataTypeRows.Next() {
   566  				var name, dataType string
   567  				dataTypeRows.Scan(&name, &dataType)
   568  				for _, c := range columnTypes {
   569  					mc := c.(*migrator.ColumnType)
   570  					if mc.NameValue.String == name {
   571  						mc.ColumnTypeValue = sql.NullString{String: dataType, Valid: true}
   572  						// Handle array type: _text -> text[] , _int4 -> integer[]
   573  						// Not support array size limits and array size limits because:
   574  						// https://www.postgresql.org/docs/current/arrays.html#ARRAYS-DECLARATION
   575  						if strings.HasPrefix(mc.DataTypeValue.String, "_") {
   576  							mc.DataTypeValue = sql.NullString{String: dataType, Valid: true}
   577  						}
   578  						break
   579  					}
   580  				}
   581  			}
   582  			dataTypeRows.Close()
   583  		}
   584  
   585  		return err
   586  	})
   587  	return
   588  }
   589  
   590  func (m Migrator) GetRows(currentSchema interface{}, table interface{}) (*sql.Rows, error) {
   591  	name := table.(string)
   592  	if _, ok := currentSchema.(string); ok {
   593  		name = fmt.Sprintf("%v.%v", currentSchema, table)
   594  	}
   595  
   596  	return m.DB.Session(&gorm.Session{}).Table(name).Limit(1).Rows()
   597  }
   598  
   599  func (m Migrator) CurrentSchema(stmt *gorm.Statement, table string) (interface{}, interface{}) {
   600  	if strings.Contains(table, ".") {
   601  		if tables := strings.Split(table, `.`); len(tables) == 2 {
   602  			return tables[0], tables[1]
   603  		}
   604  	}
   605  
   606  	if stmt.TableExpr != nil {
   607  		if tables := strings.Split(stmt.TableExpr.SQL, `"."`); len(tables) == 2 {
   608  			return strings.TrimPrefix(tables[0], `"`), table
   609  		}
   610  	}
   611  	return clause.Expr{SQL: "CURRENT_SCHEMA()"}, table
   612  }
   613  
   614  func (m Migrator) CreateSequence(tx *gorm.DB, stmt *gorm.Statement, field *schema.Field,
   615  	serialDatabaseType string) (err error) {
   616  
   617  	_, table := m.CurrentSchema(stmt, stmt.Table)
   618  	tableName := table.(string)
   619  
   620  	sequenceName := strings.Join([]string{tableName, field.DBName, "seq"}, "_")
   621  	if err = tx.Exec(`CREATE SEQUENCE IF NOT EXISTS ? AS ?`, clause.Expr{SQL: sequenceName},
   622  		clause.Expr{SQL: serialDatabaseType}).Error; err != nil {
   623  		return err
   624  	}
   625  
   626  	if err := tx.Exec("ALTER TABLE ? ALTER COLUMN ? SET DEFAULT nextval('?')",
   627  		clause.Expr{SQL: tableName}, clause.Expr{SQL: field.DBName}, clause.Expr{SQL: sequenceName}).Error; err != nil {
   628  		return err
   629  	}
   630  
   631  	if err := tx.Exec("ALTER SEQUENCE ? OWNED BY ?.?",
   632  		clause.Expr{SQL: sequenceName}, clause.Expr{SQL: tableName}, clause.Expr{SQL: field.DBName}).Error; err != nil {
   633  		return err
   634  	}
   635  	return
   636  }
   637  
   638  func (m Migrator) UpdateSequence(tx *gorm.DB, stmt *gorm.Statement, field *schema.Field,
   639  	serialDatabaseType string) (err error) {
   640  
   641  	sequenceName, err := m.getColumnSequenceName(tx, stmt, field)
   642  	if err != nil {
   643  		return err
   644  	}
   645  
   646  	if err = tx.Exec(`ALTER SEQUENCE IF EXISTS ? AS ?`, clause.Expr{SQL: sequenceName}, clause.Expr{SQL: serialDatabaseType}).Error; err != nil {
   647  		return err
   648  	}
   649  
   650  	if err := tx.Exec("ALTER TABLE ? ALTER COLUMN ? TYPE ?",
   651  		m.CurrentTable(stmt), clause.Expr{SQL: field.DBName}, clause.Expr{SQL: serialDatabaseType}).Error; err != nil {
   652  		return err
   653  	}
   654  	return
   655  }
   656  
   657  func (m Migrator) DeleteSequence(tx *gorm.DB, stmt *gorm.Statement, field *schema.Field,
   658  	fileType clause.Expr) (err error) {
   659  
   660  	sequenceName, err := m.getColumnSequenceName(tx, stmt, field)
   661  	if err != nil {
   662  		return err
   663  	}
   664  
   665  	if err := tx.Exec("ALTER TABLE ? ALTER COLUMN ? TYPE ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType).Error; err != nil {
   666  		return err
   667  	}
   668  
   669  	if err := tx.Exec("ALTER TABLE ? ALTER COLUMN ? DROP DEFAULT",
   670  		m.CurrentTable(stmt), clause.Expr{SQL: field.DBName}).Error; err != nil {
   671  		return err
   672  	}
   673  
   674  	if err = tx.Exec(`DROP SEQUENCE IF EXISTS ?`, clause.Expr{SQL: sequenceName}).Error; err != nil {
   675  		return err
   676  	}
   677  
   678  	return
   679  }
   680  
   681  func (m Migrator) getColumnSequenceName(tx *gorm.DB, stmt *gorm.Statement, field *schema.Field) (
   682  	sequenceName string, err error) {
   683  	_, table := m.CurrentSchema(stmt, stmt.Table)
   684  
   685  	// DefaultValueValue is reset by ColumnTypes, search again.
   686  	var columnDefault string
   687  	err = tx.Raw(
   688  		`SELECT column_default FROM information_schema.columns WHERE table_name = ? AND column_name = ?`,
   689  		table, field.DBName).Scan(&columnDefault).Error
   690  
   691  	if err != nil {
   692  		return
   693  	}
   694  
   695  	sequenceName = strings.TrimSuffix(
   696  		strings.TrimPrefix(columnDefault, `nextval('`),
   697  		`'::regclass)`,
   698  	)
   699  	return
   700  }
   701  
   702  func (m Migrator) GetIndexes(value interface{}) ([]gorm.Index, error) {
   703  	indexes := make([]gorm.Index, 0)
   704  
   705  	err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
   706  		result := make([]*Index, 0)
   707  		scanErr := m.DB.Raw(indexSql, stmt.Table).Scan(&result).Error
   708  		if scanErr != nil {
   709  			return scanErr
   710  		}
   711  		indexMap := groupByIndexName(result)
   712  		for _, idx := range indexMap {
   713  			tempIdx := &migrator.Index{
   714  				TableName: idx[0].TableName,
   715  				NameValue: idx[0].IndexName,
   716  				PrimaryKeyValue: sql.NullBool{
   717  					Bool:  idx[0].Primary,
   718  					Valid: true,
   719  				},
   720  				UniqueValue: sql.NullBool{
   721  					Bool:  idx[0].NonUnique,
   722  					Valid: true,
   723  				},
   724  			}
   725  			for _, x := range idx {
   726  				tempIdx.ColumnList = append(tempIdx.ColumnList, x.ColumnName)
   727  			}
   728  			indexes = append(indexes, tempIdx)
   729  		}
   730  		return nil
   731  	})
   732  	return indexes, err
   733  }
   734  
// Index is one row of indexSql: a single (index, column) pair of a table index.
type Index struct {
	TableName  string `gorm:"column:table_name"`
	ColumnName string `gorm:"column:column_name"`
	IndexName  string `gorm:"column:index_name"`
	// NonUnique is filled from the non_unique column, which indexSql selects from
	// pg_index.indisunique: despite the name it is true for UNIQUE indexes.
	NonUnique  bool   `gorm:"column:non_unique"`
	// Primary is filled from pg_index.indisprimary.
	Primary    bool   `gorm:"column:primary"`
}
   743  
   744  func groupByIndexName(indexList []*Index) map[string][]*Index {
   745  	columnIndexMap := make(map[string][]*Index, len(indexList))
   746  	for _, idx := range indexList {
   747  		columnIndexMap[idx.IndexName] = append(columnIndexMap[idx.IndexName], idx)
   748  	}
   749  	return columnIndexMap
   750  }
   751  
   752  func (m Migrator) GetTypeAliases(databaseTypeName string) []string {
   753  	return typeAliasMap[databaseTypeName]
   754  }
   755  
   756  // should reset prepared stmts when table changed
   757  func (m Migrator) resetPreparedStmts() {
   758  	if m.DB.PrepareStmt {
   759  		if pdb, ok := m.DB.ConnPool.(*gorm.PreparedStmtDB); ok {
   760  			pdb.Reset()
   761  		}
   762  	}
   763  }
   764  
   765  func (m Migrator) DropColumn(dst interface{}, field string) error {
   766  	if err := m.Migrator.DropColumn(dst, field); err != nil {
   767  		return err
   768  	}
   769  
   770  	m.resetPreparedStmts()
   771  	return nil
   772  }
   773  
   774  func (m Migrator) RenameColumn(dst interface{}, oldName, field string) error {
   775  	if err := m.Migrator.RenameColumn(dst, oldName, field); err != nil {
   776  		return err
   777  	}
   778  
   779  	m.resetPreparedStmts()
   780  	return nil
   781  }
   782  
   783  func parseDefaultValueValue(defaultValue string) string {
   784  	return regexp.MustCompile(`^(.*?)(?:::.*)?$`).ReplaceAllString(defaultValue, "$1")
   785  }