github.com/wfusion/gofusion@v1.1.14/common/infra/drivers/orm/sqlite/migrator.go (about)

     1  package sqlite
     2  
     3  import (
     4  	"database/sql"
     5  	"fmt"
     6  	"strings"
     7  
     8  	"gorm.io/gorm"
     9  	"gorm.io/gorm/clause"
    10  	"gorm.io/gorm/migrator"
    11  	"gorm.io/gorm/schema"
    12  )
    13  
// Migrator is the SQLite migrator. It embeds gorm's generic migrator.Migrator
// and the methods declared on it in this file issue SQLite-specific SQL
// (sqlite_master lookups, PRAGMA statements, and table re-creation).
type Migrator struct {
	migrator.Migrator
}
    17  
    18  func (m *Migrator) RunWithoutForeignKey(fc func() error) error {
    19  	var enabled int
    20  	m.DB.Raw("PRAGMA foreign_keys").Scan(&enabled)
    21  	if enabled == 1 {
    22  		m.DB.Exec("PRAGMA foreign_keys = OFF")
    23  		defer m.DB.Exec("PRAGMA foreign_keys = ON")
    24  	}
    25  
    26  	return fc()
    27  }
    28  
    29  func (m Migrator) HasTable(value interface{}) bool {
    30  	var count int
    31  	m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
    32  		return m.DB.Raw("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", stmt.Table).Row().Scan(&count)
    33  	})
    34  	return count > 0
    35  }
    36  
    37  func (m Migrator) DropTable(values ...interface{}) error {
    38  	return m.RunWithoutForeignKey(func() error {
    39  		values = m.ReorderModels(values, false)
    40  		tx := m.DB.Session(&gorm.Session{})
    41  
    42  		for i := len(values) - 1; i >= 0; i-- {
    43  			if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error {
    44  				return tx.Exec("DROP TABLE IF EXISTS ?", clause.Table{Name: stmt.Table}).Error
    45  			}); err != nil {
    46  				return err
    47  			}
    48  		}
    49  
    50  		return nil
    51  	})
    52  }
    53  
    54  func (m Migrator) GetTables() (tableList []string, err error) {
    55  	return tableList, m.DB.Raw("SELECT name FROM sqlite_master where type=?", "table").Scan(&tableList).Error
    56  }
    57  
    58  func (m Migrator) HasColumn(value interface{}, name string) bool {
    59  	var count int
    60  	m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
    61  		if stmt.Schema != nil {
    62  			if field := stmt.Schema.LookUpField(name); field != nil {
    63  				name = field.DBName
    64  			}
    65  		}
    66  
    67  		if name != "" {
    68  			m.DB.Raw(
    69  				"SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND (sql LIKE ? OR sql LIKE ? OR sql LIKE ? OR sql LIKE ? OR sql LIKE ?)",
    70  				"table", stmt.Table, `%"`+name+`" %`, `%`+name+` %`, "%`"+name+"`%", "%["+name+"]%", "%\t"+name+"\t%",
    71  			).Row().Scan(&count)
    72  		}
    73  		return nil
    74  	})
    75  	return count > 0
    76  }
    77  
    78  func (m Migrator) AlterColumn(value interface{}, name string) error {
    79  	return m.RunWithoutForeignKey(func() error {
    80  		return m.recreateTable(value, nil, func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
    81  			if field := stmt.Schema.LookUpField(name); field != nil {
    82  				var sqlArgs []interface{}
    83  				for i, f := range ddl.fields {
    84  					if matches := columnRegexp.FindStringSubmatch(f); len(matches) > 1 && matches[1] == field.DBName {
    85  						ddl.fields[i] = fmt.Sprintf("`%v` ?", field.DBName)
    86  						sqlArgs = []interface{}{m.FullDataTypeOf(field)}
    87  						// table created by old version might look like `CREATE TABLE ? (? varchar(10) UNIQUE)`.
    88  						// FullDataTypeOf doesn't contain UNIQUE, so we need to add unique constraint.
    89  						if strings.Contains(strings.ToUpper(matches[3]), " UNIQUE") {
    90  							uniName := m.DB.NamingStrategy.UniqueName(stmt.Table, field.DBName)
    91  							uni, _ := m.GuessConstraintInterfaceAndTable(stmt, uniName)
    92  							if uni != nil {
    93  								uniSQL, uniArgs := uni.Build()
    94  								ddl.addConstraint(uniName, uniSQL)
    95  								sqlArgs = append(sqlArgs, uniArgs...)
    96  							}
    97  						}
    98  						break
    99  					}
   100  				}
   101  				return ddl, sqlArgs, nil
   102  			}
   103  			return nil, nil, fmt.Errorf("failed to alter field with name %v", name)
   104  		})
   105  	})
   106  }
   107  
   108  // ColumnTypes return columnTypes []gorm.ColumnType and execErr error
   109  func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
   110  	columnTypes := make([]gorm.ColumnType, 0)
   111  	execErr := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
   112  		var (
   113  			sqls   []string
   114  			sqlDDL *ddl
   115  		)
   116  
   117  		if err := m.DB.Raw("SELECT sql FROM sqlite_master WHERE type IN ? AND tbl_name = ? AND sql IS NOT NULL order by type = ? desc", []string{"table", "index"}, stmt.Table, "table").Scan(&sqls).Error; err != nil {
   118  			return err
   119  		}
   120  
   121  		if sqlDDL, err = parseDDL(sqls...); err != nil {
   122  			return err
   123  		}
   124  
   125  		rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows()
   126  		if err != nil {
   127  			return err
   128  		}
   129  		defer func() {
   130  			err = rows.Close()
   131  		}()
   132  
   133  		var rawColumnTypes []*sql.ColumnType
   134  		rawColumnTypes, err = rows.ColumnTypes()
   135  		if err != nil {
   136  			return err
   137  		}
   138  
   139  		for _, c := range rawColumnTypes {
   140  			columnType := migrator.ColumnType{SQLColumnType: c}
   141  			for _, column := range sqlDDL.columns {
   142  				if column.NameValue.String == c.Name() {
   143  					column.SQLColumnType = c
   144  					columnType = column
   145  					break
   146  				}
   147  			}
   148  			columnTypes = append(columnTypes, columnType)
   149  		}
   150  
   151  		return err
   152  	})
   153  
   154  	return columnTypes, execErr
   155  }
   156  
   157  func (m Migrator) DropColumn(value interface{}, name string) error {
   158  	return m.recreateTable(value, nil, func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
   159  		if field := stmt.Schema.LookUpField(name); field != nil {
   160  			name = field.DBName
   161  		}
   162  
   163  		ddl.removeColumn(name)
   164  		return ddl, nil, nil
   165  	})
   166  }
   167  
   168  func (m Migrator) CreateConstraint(value interface{}, name string) error {
   169  	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
   170  		constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
   171  
   172  		return m.recreateTable(value, &table,
   173  			func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
   174  				var (
   175  					constraintName   string
   176  					constraintSql    string
   177  					constraintValues []interface{}
   178  				)
   179  
   180  				if constraint != nil {
   181  					constraintName = constraint.GetName()
   182  					constraintSql, constraintValues = constraint.Build()
   183  				} else {
   184  					return nil, nil, nil
   185  				}
   186  
   187  				ddl.addConstraint(constraintName, constraintSql)
   188  				return ddl, constraintValues, nil
   189  			})
   190  	})
   191  }
   192  
   193  func (m Migrator) DropConstraint(value interface{}, name string) error {
   194  	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
   195  		constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
   196  		if constraint != nil {
   197  			name = constraint.GetName()
   198  		}
   199  
   200  		return m.recreateTable(value, &table,
   201  			func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
   202  				ddl.removeConstraint(name)
   203  				return ddl, nil, nil
   204  			})
   205  	})
   206  }
   207  
   208  func (m Migrator) HasConstraint(value interface{}, name string) bool {
   209  	var count int64
   210  	m.RunWithValue(value, func(stmt *gorm.Statement) error {
   211  		constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
   212  		if constraint != nil {
   213  			name = constraint.GetName()
   214  		}
   215  
   216  		m.DB.Raw(
   217  			"SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND (sql LIKE ? OR sql LIKE ? OR sql LIKE ? OR sql LIKE ? OR sql LIKE ?)",
   218  			"table", table, `%CONSTRAINT "`+name+`" %`, `%CONSTRAINT `+name+` %`, "%CONSTRAINT `"+name+"`%", "%CONSTRAINT ["+name+"]%", "%CONSTRAINT \t"+name+"\t%",
   219  		).Row().Scan(&count)
   220  
   221  		return nil
   222  	})
   223  
   224  	return count > 0
   225  }
   226  
   227  func (m Migrator) CurrentDatabase() (name string) {
   228  	var null interface{}
   229  	m.DB.Raw("PRAGMA database_list").Row().Scan(&null, &name, &null)
   230  	return
   231  }
   232  
   233  func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
   234  	for _, opt := range opts {
   235  		str := stmt.Quote(opt.DBName)
   236  		if opt.Expression != "" {
   237  			str = opt.Expression
   238  		}
   239  
   240  		if opt.Collate != "" {
   241  			str += " COLLATE " + opt.Collate
   242  		}
   243  
   244  		if opt.Sort != "" {
   245  			str += " " + opt.Sort
   246  		}
   247  		results = append(results, clause.Expr{SQL: str})
   248  	}
   249  	return
   250  }
   251  
   252  func (m Migrator) CreateIndex(value interface{}, name string) error {
   253  	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
   254  		if stmt.Schema != nil {
   255  			if idx := stmt.Schema.LookIndex(name); idx != nil {
   256  				opts := m.BuildIndexOptions(idx.Fields, stmt)
   257  				values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts}
   258  
   259  				createIndexSQL := "CREATE "
   260  				if idx.Class != "" {
   261  					createIndexSQL += idx.Class + " "
   262  				}
   263  				createIndexSQL += "INDEX ?"
   264  
   265  				if idx.Type != "" {
   266  					createIndexSQL += " USING " + idx.Type
   267  				}
   268  				createIndexSQL += " ON ??"
   269  
   270  				if idx.Where != "" {
   271  					createIndexSQL += " WHERE " + idx.Where
   272  				}
   273  
   274  				return m.DB.Exec(createIndexSQL, values...).Error
   275  			}
   276  		}
   277  		return fmt.Errorf("failed to create index with name %v", name)
   278  	})
   279  }
   280  
   281  func (m Migrator) HasIndex(value interface{}, name string) bool {
   282  	var count int
   283  	m.RunWithValue(value, func(stmt *gorm.Statement) error {
   284  		if stmt.Schema != nil {
   285  			if idx := stmt.Schema.LookIndex(name); idx != nil {
   286  				name = idx.Name
   287  			}
   288  		}
   289  
   290  		if name != "" {
   291  			m.DB.Raw(
   292  				"SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, name,
   293  			).Row().Scan(&count)
   294  		}
   295  		return nil
   296  	})
   297  	return count > 0
   298  }
   299  
   300  func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
   301  	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
   302  		var sql string
   303  		m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, oldName).Row().Scan(&sql)
   304  		if sql != "" {
   305  			if err := m.DropIndex(value, oldName); err != nil {
   306  				return err
   307  			}
   308  			return m.DB.Exec(strings.Replace(sql, oldName, newName, 1)).Error
   309  		}
   310  		return fmt.Errorf("failed to find index with name %v", oldName)
   311  	})
   312  }
   313  
   314  func (m Migrator) DropIndex(value interface{}, name string) error {
   315  	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
   316  		if stmt.Schema != nil {
   317  			if idx := stmt.Schema.LookIndex(name); idx != nil {
   318  				name = idx.Name
   319  			}
   320  		}
   321  
   322  		return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}).Error
   323  	})
   324  }
   325  
// Index is the scan target for one row of `PRAGMA_index_list`, as used by
// GetIndexes.
type Index struct {
	Seq     int    // presumably the row's sequence number; not read in this file
	Name    string // index name; passed to PRAGMA_index_info to list its columns
	Unique  bool   // reported as the index's UniqueValue
	Origin  string // "u" rows are skipped by GetIndexes; "pk" marks a primary key
	Partial bool   // presumably flags a partial index; not read in this file
}
   333  
   334  // GetIndexes return Indexes []gorm.Index and execErr error,
   335  // See the [doc]
   336  //
   337  // [doc]: https://www.sqlite.org/pragma.html#pragma_index_list
   338  func (m Migrator) GetIndexes(value interface{}) ([]gorm.Index, error) {
   339  	indexes := make([]gorm.Index, 0)
   340  	err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
   341  		rst := make([]*Index, 0)
   342  		if err := m.DB.Debug().Raw("SELECT * FROM PRAGMA_index_list(?)", stmt.Table).Scan(&rst).Error; err != nil { // alias `PRAGMA index_list(?)`
   343  			return err
   344  		}
   345  		for _, index := range rst {
   346  			if index.Origin == "u" { // skip the index was created by a UNIQUE constraint
   347  				continue
   348  			}
   349  			var columns []string
   350  			if err := m.DB.Raw("SELECT name FROM PRAGMA_index_info(?)", index.Name).Scan(&columns).Error; err != nil { // alias `PRAGMA index_info(?)`
   351  				return err
   352  			}
   353  			indexes = append(indexes, &migrator.Index{
   354  				TableName:       stmt.Table,
   355  				NameValue:       index.Name,
   356  				ColumnList:      columns,
   357  				PrimaryKeyValue: sql.NullBool{Bool: index.Origin == "pk", Valid: true}, // The exceptions are INTEGER PRIMARY KEY
   358  				UniqueValue:     sql.NullBool{Bool: index.Unique, Valid: true},
   359  			})
   360  		}
   361  		return nil
   362  	})
   363  	return indexes, err
   364  }
   365  
   366  func (m Migrator) getRawDDL(table string) (string, error) {
   367  	var createSQL string
   368  	m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "table", table, table).Row().Scan(&createSQL)
   369  
   370  	if m.DB.Error != nil {
   371  		return "", m.DB.Error
   372  	}
   373  	return createSQL, nil
   374  }
   375  
   376  func (m Migrator) recreateTable(
   377  	value interface{}, tablePtr *string,
   378  	getCreateSQL func(ddl *ddl, stmt *gorm.Statement) (sql *ddl, sqlArgs []interface{}, err error),
   379  ) error {
   380  	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
   381  		table := stmt.Table
   382  		if tablePtr != nil {
   383  			table = *tablePtr
   384  		}
   385  
   386  		rawDDL, err := m.getRawDDL(table)
   387  		if err != nil {
   388  			return err
   389  		}
   390  
   391  		originDDL, err := parseDDL(rawDDL)
   392  		if err != nil {
   393  			return err
   394  		}
   395  
   396  		createDDL, sqlArgs, err := getCreateSQL(originDDL.clone(), stmt)
   397  		if err != nil {
   398  			return err
   399  		}
   400  		if createDDL == nil {
   401  			return nil
   402  		}
   403  
   404  		newTableName := table + "__temp"
   405  		if err := createDDL.renameTable(newTableName, table); err != nil {
   406  			return err
   407  		}
   408  
   409  		columns := createDDL.getColumns()
   410  		createSQL := createDDL.compile()
   411  
   412  		return m.DB.Transaction(func(tx *gorm.DB) error {
   413  			if err := tx.Exec(createSQL, sqlArgs...).Error; err != nil {
   414  				return err
   415  			}
   416  
   417  			queries := []string{
   418  				fmt.Sprintf("INSERT INTO `%v`(%v) SELECT %v FROM `%v`", newTableName, strings.Join(columns, ","), strings.Join(columns, ","), table),
   419  				fmt.Sprintf("DROP TABLE `%v`", table),
   420  				fmt.Sprintf("ALTER TABLE `%v` RENAME TO `%v`", newTableName, table),
   421  			}
   422  			for _, query := range queries {
   423  				if err := tx.Exec(query).Error; err != nil {
   424  					return err
   425  				}
   426  			}
   427  			return nil
   428  		})
   429  	})
   430  }