github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/sqlx/database.go (about) 1 package sqlx 2 3 import ( 4 "fmt" 5 "os" 6 "reflect" 7 8 "github.com/johnnyeven/libtools/sqlx/builder" 9 10 "github.com/sirupsen/logrus" 11 ) 12 13 func NewFeatureDatabase(name string) *Database { 14 if projectFeature, exists := os.LookupEnv("PROJECT_FEATURE"); exists && projectFeature != "" { 15 name = name + "__" + projectFeature 16 } 17 return NewDatabase(name) 18 } 19 20 func NewDatabase(name string) *Database { 21 return &Database{ 22 Database: builder.DB(name), 23 } 24 } 25 26 type Database struct { 27 *builder.Database 28 } 29 30 func (database *Database) Register(model Model) *builder.Table { 31 database.mustStructType(model) 32 rv := reflect.Indirect(reflect.ValueOf(model)) 33 table := builder.T(database.Database, model.TableName()) 34 ScanDefToTable(rv, table) 35 database.Database.Register(table) 36 return table 37 } 38 39 func (database Database) T(model Model) *builder.Table { 40 database.mustStructType(model) 41 return database.Database.Table(model.TableName()) 42 } 43 44 func (database Database) mustStructType(model Model) { 45 tpe := reflect.TypeOf(model) 46 if tpe.Kind() != reflect.Ptr { 47 panic(fmt.Errorf("model %s must be a pointer", tpe.Name())) 48 } 49 tpe = tpe.Elem() 50 if tpe.Kind() != reflect.Struct { 51 panic(fmt.Errorf("model %s must be a struct", tpe.Name())) 52 } 53 } 54 55 func (database *Database) Insert(model Model) *builder.StmtInsert { 56 table := database.T(model) 57 58 fieldValues := FieldValuesFromStructByNonZero(model) 59 60 if autoIncrementCol := table.AutoIncrement(); autoIncrementCol != nil { 61 delete(fieldValues, autoIncrementCol.FieldName) 62 } 63 64 cols, vals := table.ColumnsAndValuesByFieldValues(fieldValues) 65 66 return table.Insert().Columns(cols).Values(vals...) 67 } 68 69 func (database *Database) Update(model Model, zeroFields ...string) *builder.StmtUpdate { 70 table := database.T(model) 71 72 fieldValues := FieldValuesFromStructByNonZero(model, zeroFields...) 73 74 if autoIncrementCol := table.AutoIncrement(); autoIncrementCol != nil { 75 delete(fieldValues, autoIncrementCol.FieldName) 76 } 77 78 return table.Update().Set(table.AssignsByFieldValues(fieldValues)...) 79 } 80 81 func (database *Database) MustMigrateTo(db *DB, dryRun bool) { 82 if err := database.MigrateTo(db, dryRun); err != nil { 83 logrus.Panic(err) 84 } 85 } 86 87 func (database *Database) MigrateTo(db *DB, dryRun bool) error { 88 database.Register(&SqlMetaEnum{}) 89 90 currentDatabase := DBFromInformationSchema(db, database.Name, database.Tables.TableNames()...) 91 92 if !dryRun { 93 logrus.Debugf("=================== migrating database `%s` ====================", database.Name) 94 defer logrus.Debugf("=================== migrated database `%s` ====================", database.Name) 95 96 if currentDatabase == nil { 97 currentDatabase = &Database{ 98 Database: builder.DB(database.Name), 99 } 100 if err := db.Do(currentDatabase.Create(true)).Err(); err != nil { 101 return err 102 } 103 } 104 105 for name, table := range database.Tables { 106 currentTable := currentDatabase.Table(name) 107 if currentTable == nil { 108 if err := db.Do(table.Create(true)).Err(); err != nil { 109 return err 110 } 111 continue 112 } 113 114 stmt := currentTable.Diff(table) 115 if stmt != nil { 116 if err := db.Do(stmt).Err(); err != nil { 117 return err 118 } 119 continue 120 } 121 } 122 123 if err := database.SyncEnum(db); err != nil { 124 return err 125 } 126 127 return nil 128 } 129 130 if currentDatabase == nil { 131 currentDatabase = &Database{ 132 Database: builder.DB(database.Name), 133 } 134 135 fmt.Printf("=================== need to migrate database `%s` ====================\n", database.Name) 136 fmt.Println(currentDatabase.Create(true).Query) 137 fmt.Printf("=================== need to migrate database `%s` ====================\n", database.Name) 138 } 139 140 for name, table := range database.Tables { 141 currentTable := currentDatabase.Table(name) 142 if currentTable == nil { 143 fmt.Println(table.Create(true).Query) 144 continue 145 } 146 147 stmt := currentTable.Diff(table) 148 if stmt != nil { 149 fmt.Println(stmt.Query) 150 continue 151 } 152 } 153 154 if err := database.SyncEnum(db); err != nil { 155 return err 156 } 157 158 return nil 159 }