github.com/artisanhe/tools@v1.0.1-0.20210607022958-19a8fef2eb04/sqlx/database.go (about) 1 package sqlx 2 3 import ( 4 "fmt" 5 "os" 6 "reflect" 7 8 "github.com/sirupsen/logrus" 9 10 "github.com/artisanhe/tools/sqlx/builder" 11 ) 12 13 func NewFeatureDatabase(name string) *Database { 14 if projectFeature, exists := os.LookupEnv("PROJECT_FEATURE"); exists && projectFeature != "" { 15 name = name + "__" + projectFeature 16 } 17 return NewDatabase(name) 18 } 19 20 func NewDatabase(name string) *Database { 21 return &Database{ 22 Database: builder.DB(name), 23 } 24 } 25 26 type Database struct { 27 *builder.Database 28 } 29 30 func (database *Database) SetName(name string) { 31 database.Database.Name = name 32 } 33 34 func (database *Database) Register(model Model) *builder.Table { 35 database.mustStructType(model) 36 rv := reflect.Indirect(reflect.ValueOf(model)) 37 table := builder.T(database.Database, model.TableName()) 38 ScanDefToTable(rv, table) 39 database.Database.Register(table) 40 return table 41 } 42 43 func (database Database) T(model Model) *builder.Table { 44 database.mustStructType(model) 45 return database.Database.Table(model.TableName()) 46 } 47 48 func (database Database) mustStructType(model Model) { 49 tpe := reflect.TypeOf(model) 50 if tpe.Kind() != reflect.Ptr { 51 panic(fmt.Errorf("model %s must be a pointer", tpe.Name())) 52 } 53 tpe = tpe.Elem() 54 if tpe.Kind() != reflect.Struct { 55 panic(fmt.Errorf("model %s must be a struct", tpe.Name())) 56 } 57 } 58 59 func (database *Database) Insert(model Model) *builder.StmtInsert { 60 table := database.T(model) 61 62 fieldValues := FieldValuesFromStructByNonZero(model) 63 64 if autoIncrementCol := table.AutoIncrement(); autoIncrementCol != nil { 65 delete(fieldValues, autoIncrementCol.FieldName) 66 } 67 68 cols, vals := table.ColumnsAndValuesByFieldValues(fieldValues) 69 70 return table.Insert().Columns(cols).Values(vals...) 71 } 72 73 func (database *Database) Update(model Model, zeroFields ...string) *builder.StmtUpdate { 74 table := database.T(model) 75 76 fieldValues := FieldValuesFromStructByNonZero(model, zeroFields...) 77 78 if autoIncrementCol := table.AutoIncrement(); autoIncrementCol != nil { 79 delete(fieldValues, autoIncrementCol.FieldName) 80 } 81 82 return table.Update().Set(table.AssignsByFieldValues(fieldValues)...) 83 } 84 85 func (database *Database) MustMigrateTo(db *DB, dryRun bool) { 86 if err := database.MigrateTo(db, dryRun); err != nil { 87 logrus.Panic(err) 88 } 89 } 90 91 func (database *Database) MigrateTo(db *DB, dryRun bool) error { 92 database.Register(&SqlMetaEnum{}) 93 94 currentDatabase := DBFromInformationSchema(db, database.Name, database.Tables.TableNames()...) 95 96 if !dryRun { 97 logrus.Debugf("=================== migrating database `%s` ====================", database.Name) 98 defer logrus.Debugf("=================== migrated database `%s` ====================", database.Name) 99 100 if currentDatabase == nil { 101 currentDatabase = &Database{ 102 Database: builder.DB(database.Name), 103 } 104 if err := db.Do(currentDatabase.Create(true)).Err(); err != nil { 105 return err 106 } 107 } 108 109 for name, table := range database.Tables { 110 currentTable := currentDatabase.Table(name) 111 if currentTable == nil { 112 if err := db.Do(table.Create(true)).Err(); err != nil { 113 return err 114 } 115 continue 116 } 117 118 stmt := currentTable.Diff(table) 119 if stmt != nil { 120 if err := db.Do(stmt).Err(); err != nil { 121 return err 122 } 123 continue 124 } 125 } 126 127 if err := database.SyncEnum(db); err != nil { 128 return err 129 } 130 131 return nil 132 } 133 134 if currentDatabase == nil { 135 currentDatabase = &Database{ 136 Database: builder.DB(database.Name), 137 } 138 139 fmt.Printf("=================== need to migrate database `%s` ====================\n", database.Name) 140 fmt.Println(currentDatabase.Create(true).Query) 141 fmt.Printf("=================== need to migrate database `%s` ====================\n", database.Name) 142 } 143 144 for name, table := range database.Tables { 145 currentTable := currentDatabase.Table(name) 146 if currentTable == nil { 147 fmt.Println(table.Create(true).Query) 148 continue 149 } 150 151 stmt := currentTable.Diff(table) 152 if stmt != nil { 153 fmt.Println(stmt.Query) 154 continue 155 } 156 } 157 158 if err := database.SyncEnum(db); err != nil { 159 return err 160 } 161 162 return nil 163 }