github.com/wfusion/gofusion@v1.1.14/common/infra/drivers/orm/sqlite/migrator.go (about) 1 package sqlite 2 3 import ( 4 "database/sql" 5 "fmt" 6 "strings" 7 8 "gorm.io/gorm" 9 "gorm.io/gorm/clause" 10 "gorm.io/gorm/migrator" 11 "gorm.io/gorm/schema" 12 ) 13 14 type Migrator struct { 15 migrator.Migrator 16 } 17 18 func (m *Migrator) RunWithoutForeignKey(fc func() error) error { 19 var enabled int 20 m.DB.Raw("PRAGMA foreign_keys").Scan(&enabled) 21 if enabled == 1 { 22 m.DB.Exec("PRAGMA foreign_keys = OFF") 23 defer m.DB.Exec("PRAGMA foreign_keys = ON") 24 } 25 26 return fc() 27 } 28 29 func (m Migrator) HasTable(value interface{}) bool { 30 var count int 31 m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error { 32 return m.DB.Raw("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", stmt.Table).Row().Scan(&count) 33 }) 34 return count > 0 35 } 36 37 func (m Migrator) DropTable(values ...interface{}) error { 38 return m.RunWithoutForeignKey(func() error { 39 values = m.ReorderModels(values, false) 40 tx := m.DB.Session(&gorm.Session{}) 41 42 for i := len(values) - 1; i >= 0; i-- { 43 if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error { 44 return tx.Exec("DROP TABLE IF EXISTS ?", clause.Table{Name: stmt.Table}).Error 45 }); err != nil { 46 return err 47 } 48 } 49 50 return nil 51 }) 52 } 53 54 func (m Migrator) GetTables() (tableList []string, err error) { 55 return tableList, m.DB.Raw("SELECT name FROM sqlite_master where type=?", "table").Scan(&tableList).Error 56 } 57 58 func (m Migrator) HasColumn(value interface{}, name string) bool { 59 var count int 60 m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error { 61 if stmt.Schema != nil { 62 if field := stmt.Schema.LookUpField(name); field != nil { 63 name = field.DBName 64 } 65 } 66 67 if name != "" { 68 m.DB.Raw( 69 "SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND (sql LIKE ? OR sql LIKE ? OR sql LIKE ? OR sql LIKE ? OR sql LIKE ?)", 70 "table", stmt.Table, `%"`+name+`" %`, `%`+name+` %`, "%`"+name+"`%", "%["+name+"]%", "%\t"+name+"\t%", 71 ).Row().Scan(&count) 72 } 73 return nil 74 }) 75 return count > 0 76 } 77 78 func (m Migrator) AlterColumn(value interface{}, name string) error { 79 return m.RunWithoutForeignKey(func() error { 80 return m.recreateTable(value, nil, func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) { 81 if field := stmt.Schema.LookUpField(name); field != nil { 82 var sqlArgs []interface{} 83 for i, f := range ddl.fields { 84 if matches := columnRegexp.FindStringSubmatch(f); len(matches) > 1 && matches[1] == field.DBName { 85 ddl.fields[i] = fmt.Sprintf("`%v` ?", field.DBName) 86 sqlArgs = []interface{}{m.FullDataTypeOf(field)} 87 // table created by old version might look like `CREATE TABLE ? (? varchar(10) UNIQUE)`. 88 // FullDataTypeOf doesn't contain UNIQUE, so we need to add unique constraint. 89 if strings.Contains(strings.ToUpper(matches[3]), " UNIQUE") { 90 uniName := m.DB.NamingStrategy.UniqueName(stmt.Table, field.DBName) 91 uni, _ := m.GuessConstraintInterfaceAndTable(stmt, uniName) 92 if uni != nil { 93 uniSQL, uniArgs := uni.Build() 94 ddl.addConstraint(uniName, uniSQL) 95 sqlArgs = append(sqlArgs, uniArgs...) 96 } 97 } 98 break 99 } 100 } 101 return ddl, sqlArgs, nil 102 } 103 return nil, nil, fmt.Errorf("failed to alter field with name %v", name) 104 }) 105 }) 106 } 107 108 // ColumnTypes return columnTypes []gorm.ColumnType and execErr error 109 func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) { 110 columnTypes := make([]gorm.ColumnType, 0) 111 execErr := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) { 112 var ( 113 sqls []string 114 sqlDDL *ddl 115 ) 116 117 if err := m.DB.Raw("SELECT sql FROM sqlite_master WHERE type IN ? AND tbl_name = ? AND sql IS NOT NULL order by type = ? desc", []string{"table", "index"}, stmt.Table, "table").Scan(&sqls).Error; err != nil { 118 return err 119 } 120 121 if sqlDDL, err = parseDDL(sqls...); err != nil { 122 return err 123 } 124 125 rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows() 126 if err != nil { 127 return err 128 } 129 defer func() { 130 err = rows.Close() 131 }() 132 133 var rawColumnTypes []*sql.ColumnType 134 rawColumnTypes, err = rows.ColumnTypes() 135 if err != nil { 136 return err 137 } 138 139 for _, c := range rawColumnTypes { 140 columnType := migrator.ColumnType{SQLColumnType: c} 141 for _, column := range sqlDDL.columns { 142 if column.NameValue.String == c.Name() { 143 column.SQLColumnType = c 144 columnType = column 145 break 146 } 147 } 148 columnTypes = append(columnTypes, columnType) 149 } 150 151 return err 152 }) 153 154 return columnTypes, execErr 155 } 156 157 func (m Migrator) DropColumn(value interface{}, name string) error { 158 return m.recreateTable(value, nil, func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) { 159 if field := stmt.Schema.LookUpField(name); field != nil { 160 name = field.DBName 161 } 162 163 ddl.removeColumn(name) 164 return ddl, nil, nil 165 }) 166 } 167 168 func (m Migrator) CreateConstraint(value interface{}, name string) error { 169 return m.RunWithValue(value, func(stmt *gorm.Statement) error { 170 constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name) 171 172 return m.recreateTable(value, &table, 173 func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) { 174 var ( 175 constraintName string 176 constraintSql string 177 constraintValues []interface{} 178 ) 179 180 if constraint != nil { 181 constraintName = constraint.GetName() 182 constraintSql, constraintValues = constraint.Build() 183 } else { 184 return nil, nil, nil 185 } 186 187 ddl.addConstraint(constraintName, constraintSql) 188 return ddl, constraintValues, nil 189 }) 190 }) 191 } 192 193 func (m Migrator) DropConstraint(value interface{}, name string) error { 194 return m.RunWithValue(value, func(stmt *gorm.Statement) error { 195 constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name) 196 if constraint != nil { 197 name = constraint.GetName() 198 } 199 200 return m.recreateTable(value, &table, 201 func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) { 202 ddl.removeConstraint(name) 203 return ddl, nil, nil 204 }) 205 }) 206 } 207 208 func (m Migrator) HasConstraint(value interface{}, name string) bool { 209 var count int64 210 m.RunWithValue(value, func(stmt *gorm.Statement) error { 211 constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name) 212 if constraint != nil { 213 name = constraint.GetName() 214 } 215 216 m.DB.Raw( 217 "SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND (sql LIKE ? OR sql LIKE ? OR sql LIKE ? OR sql LIKE ? OR sql LIKE ?)", 218 "table", table, `%CONSTRAINT "`+name+`" %`, `%CONSTRAINT `+name+` %`, "%CONSTRAINT `"+name+"`%", "%CONSTRAINT ["+name+"]%", "%CONSTRAINT \t"+name+"\t%", 219 ).Row().Scan(&count) 220 221 return nil 222 }) 223 224 return count > 0 225 } 226 227 func (m Migrator) CurrentDatabase() (name string) { 228 var null interface{} 229 m.DB.Raw("PRAGMA database_list").Row().Scan(&null, &name, &null) 230 return 231 } 232 233 func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) { 234 for _, opt := range opts { 235 str := stmt.Quote(opt.DBName) 236 if opt.Expression != "" { 237 str = opt.Expression 238 } 239 240 if opt.Collate != "" { 241 str += " COLLATE " + opt.Collate 242 } 243 244 if opt.Sort != "" { 245 str += " " + opt.Sort 246 } 247 results = append(results, clause.Expr{SQL: str}) 248 } 249 return 250 } 251 252 func (m Migrator) CreateIndex(value interface{}, name string) error { 253 return m.RunWithValue(value, func(stmt *gorm.Statement) error { 254 if stmt.Schema != nil { 255 if idx := stmt.Schema.LookIndex(name); idx != nil { 256 opts := m.BuildIndexOptions(idx.Fields, stmt) 257 values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts} 258 259 createIndexSQL := "CREATE " 260 if idx.Class != "" { 261 createIndexSQL += idx.Class + " " 262 } 263 createIndexSQL += "INDEX ?" 264 265 if idx.Type != "" { 266 createIndexSQL += " USING " + idx.Type 267 } 268 createIndexSQL += " ON ??" 269 270 if idx.Where != "" { 271 createIndexSQL += " WHERE " + idx.Where 272 } 273 274 return m.DB.Exec(createIndexSQL, values...).Error 275 } 276 } 277 return fmt.Errorf("failed to create index with name %v", name) 278 }) 279 } 280 281 func (m Migrator) HasIndex(value interface{}, name string) bool { 282 var count int 283 m.RunWithValue(value, func(stmt *gorm.Statement) error { 284 if stmt.Schema != nil { 285 if idx := stmt.Schema.LookIndex(name); idx != nil { 286 name = idx.Name 287 } 288 } 289 290 if name != "" { 291 m.DB.Raw( 292 "SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, name, 293 ).Row().Scan(&count) 294 } 295 return nil 296 }) 297 return count > 0 298 } 299 300 func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error { 301 return m.RunWithValue(value, func(stmt *gorm.Statement) error { 302 var sql string 303 m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, oldName).Row().Scan(&sql) 304 if sql != "" { 305 if err := m.DropIndex(value, oldName); err != nil { 306 return err 307 } 308 return m.DB.Exec(strings.Replace(sql, oldName, newName, 1)).Error 309 } 310 return fmt.Errorf("failed to find index with name %v", oldName) 311 }) 312 } 313 314 func (m Migrator) DropIndex(value interface{}, name string) error { 315 return m.RunWithValue(value, func(stmt *gorm.Statement) error { 316 if stmt.Schema != nil { 317 if idx := stmt.Schema.LookIndex(name); idx != nil { 318 name = idx.Name 319 } 320 } 321 322 return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}).Error 323 }) 324 } 325 326 type Index struct { 327 Seq int 328 Name string 329 Unique bool 330 Origin string 331 Partial bool 332 } 333 334 // GetIndexes return Indexes []gorm.Index and execErr error, 335 // See the [doc] 336 // 337 // [doc]: https://www.sqlite.org/pragma.html#pragma_index_list 338 func (m Migrator) GetIndexes(value interface{}) ([]gorm.Index, error) { 339 indexes := make([]gorm.Index, 0) 340 err := m.RunWithValue(value, func(stmt *gorm.Statement) error { 341 rst := make([]*Index, 0) 342 if err := m.DB.Debug().Raw("SELECT * FROM PRAGMA_index_list(?)", stmt.Table).Scan(&rst).Error; err != nil { // alias `PRAGMA index_list(?)` 343 return err 344 } 345 for _, index := range rst { 346 if index.Origin == "u" { // skip the index was created by a UNIQUE constraint 347 continue 348 } 349 var columns []string 350 if err := m.DB.Raw("SELECT name FROM PRAGMA_index_info(?)", index.Name).Scan(&columns).Error; err != nil { // alias `PRAGMA index_info(?)` 351 return err 352 } 353 indexes = append(indexes, &migrator.Index{ 354 TableName: stmt.Table, 355 NameValue: index.Name, 356 ColumnList: columns, 357 PrimaryKeyValue: sql.NullBool{Bool: index.Origin == "pk", Valid: true}, // The exceptions are INTEGER PRIMARY KEY 358 UniqueValue: sql.NullBool{Bool: index.Unique, Valid: true}, 359 }) 360 } 361 return nil 362 }) 363 return indexes, err 364 } 365 366 func (m Migrator) getRawDDL(table string) (string, error) { 367 var createSQL string 368 m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "table", table, table).Row().Scan(&createSQL) 369 370 if m.DB.Error != nil { 371 return "", m.DB.Error 372 } 373 return createSQL, nil 374 } 375 376 func (m Migrator) recreateTable( 377 value interface{}, tablePtr *string, 378 getCreateSQL func(ddl *ddl, stmt *gorm.Statement) (sql *ddl, sqlArgs []interface{}, err error), 379 ) error { 380 return m.RunWithValue(value, func(stmt *gorm.Statement) error { 381 table := stmt.Table 382 if tablePtr != nil { 383 table = *tablePtr 384 } 385 386 rawDDL, err := m.getRawDDL(table) 387 if err != nil { 388 return err 389 } 390 391 originDDL, err := parseDDL(rawDDL) 392 if err != nil { 393 return err 394 } 395 396 createDDL, sqlArgs, err := getCreateSQL(originDDL.clone(), stmt) 397 if err != nil { 398 return err 399 } 400 if createDDL == nil { 401 return nil 402 } 403 404 newTableName := table + "__temp" 405 if err := createDDL.renameTable(newTableName, table); err != nil { 406 return err 407 } 408 409 columns := createDDL.getColumns() 410 createSQL := createDDL.compile() 411 412 return m.DB.Transaction(func(tx *gorm.DB) error { 413 if err := tx.Exec(createSQL, sqlArgs...).Error; err != nil { 414 return err 415 } 416 417 queries := []string{ 418 fmt.Sprintf("INSERT INTO `%v`(%v) SELECT %v FROM `%v`", newTableName, strings.Join(columns, ","), strings.Join(columns, ","), table), 419 fmt.Sprintf("DROP TABLE `%v`", table), 420 fmt.Sprintf("ALTER TABLE `%v` RENAME TO `%v`", newTableName, table), 421 } 422 for _, query := range queries { 423 if err := tx.Exec(query).Error; err != nil { 424 return err 425 } 426 } 427 return nil 428 }) 429 }) 430 }