github.com/artisanhe/tools@v1.0.1-0.20210607022958-19a8fef2eb04/sqlx/database.go (about)

     1  package sqlx
     2  
import (
	"fmt"
	"os"
	"reflect"
	"sort"

	"github.com/sirupsen/logrus"

	"github.com/artisanhe/tools/sqlx/builder"
)
    12  
    13  func NewFeatureDatabase(name string) *Database {
    14  	if projectFeature, exists := os.LookupEnv("PROJECT_FEATURE"); exists && projectFeature != "" {
    15  		name = name + "__" + projectFeature
    16  	}
    17  	return NewDatabase(name)
    18  }
    19  
    20  func NewDatabase(name string) *Database {
    21  	return &Database{
    22  		Database: builder.DB(name),
    23  	}
    24  }
    25  
    26  type Database struct {
    27  	*builder.Database
    28  }
    29  
    30  func (database *Database) SetName(name string) {
    31  	database.Database.Name = name
    32  }
    33  
    34  func (database *Database) Register(model Model) *builder.Table {
    35  	database.mustStructType(model)
    36  	rv := reflect.Indirect(reflect.ValueOf(model))
    37  	table := builder.T(database.Database, model.TableName())
    38  	ScanDefToTable(rv, table)
    39  	database.Database.Register(table)
    40  	return table
    41  }
    42  
    43  func (database Database) T(model Model) *builder.Table {
    44  	database.mustStructType(model)
    45  	return database.Database.Table(model.TableName())
    46  }
    47  
    48  func (database Database) mustStructType(model Model) {
    49  	tpe := reflect.TypeOf(model)
    50  	if tpe.Kind() != reflect.Ptr {
    51  		panic(fmt.Errorf("model %s must be a pointer", tpe.Name()))
    52  	}
    53  	tpe = tpe.Elem()
    54  	if tpe.Kind() != reflect.Struct {
    55  		panic(fmt.Errorf("model %s must be a struct", tpe.Name()))
    56  	}
    57  }
    58  
    59  func (database *Database) Insert(model Model) *builder.StmtInsert {
    60  	table := database.T(model)
    61  
    62  	fieldValues := FieldValuesFromStructByNonZero(model)
    63  
    64  	if autoIncrementCol := table.AutoIncrement(); autoIncrementCol != nil {
    65  		delete(fieldValues, autoIncrementCol.FieldName)
    66  	}
    67  
    68  	cols, vals := table.ColumnsAndValuesByFieldValues(fieldValues)
    69  
    70  	return table.Insert().Columns(cols).Values(vals...)
    71  }
    72  
    73  func (database *Database) Update(model Model, zeroFields ...string) *builder.StmtUpdate {
    74  	table := database.T(model)
    75  
    76  	fieldValues := FieldValuesFromStructByNonZero(model, zeroFields...)
    77  
    78  	if autoIncrementCol := table.AutoIncrement(); autoIncrementCol != nil {
    79  		delete(fieldValues, autoIncrementCol.FieldName)
    80  	}
    81  
    82  	return table.Update().Set(table.AssignsByFieldValues(fieldValues)...)
    83  }
    84  
    85  func (database *Database) MustMigrateTo(db *DB, dryRun bool) {
    86  	if err := database.MigrateTo(db, dryRun); err != nil {
    87  		logrus.Panic(err)
    88  	}
    89  }
    90  
    91  func (database *Database) MigrateTo(db *DB, dryRun bool) error {
    92  	database.Register(&SqlMetaEnum{})
    93  
    94  	currentDatabase := DBFromInformationSchema(db, database.Name, database.Tables.TableNames()...)
    95  
    96  	if !dryRun {
    97  		logrus.Debugf("=================== migrating database `%s` ====================", database.Name)
    98  		defer logrus.Debugf("=================== migrated database `%s` ====================", database.Name)
    99  
   100  		if currentDatabase == nil {
   101  			currentDatabase = &Database{
   102  				Database: builder.DB(database.Name),
   103  			}
   104  			if err := db.Do(currentDatabase.Create(true)).Err(); err != nil {
   105  				return err
   106  			}
   107  		}
   108  
   109  		for name, table := range database.Tables {
   110  			currentTable := currentDatabase.Table(name)
   111  			if currentTable == nil {
   112  				if err := db.Do(table.Create(true)).Err(); err != nil {
   113  					return err
   114  				}
   115  				continue
   116  			}
   117  
   118  			stmt := currentTable.Diff(table)
   119  			if stmt != nil {
   120  				if err := db.Do(stmt).Err(); err != nil {
   121  					return err
   122  				}
   123  				continue
   124  			}
   125  		}
   126  
   127  		if err := database.SyncEnum(db); err != nil {
   128  			return err
   129  		}
   130  
   131  		return nil
   132  	}
   133  
   134  	if currentDatabase == nil {
   135  		currentDatabase = &Database{
   136  			Database: builder.DB(database.Name),
   137  		}
   138  
   139  		fmt.Printf("=================== need to migrate database `%s` ====================\n", database.Name)
   140  		fmt.Println(currentDatabase.Create(true).Query)
   141  		fmt.Printf("=================== need to migrate database `%s` ====================\n", database.Name)
   142  	}
   143  
   144  	for name, table := range database.Tables {
   145  		currentTable := currentDatabase.Table(name)
   146  		if currentTable == nil {
   147  			fmt.Println(table.Create(true).Query)
   148  			continue
   149  		}
   150  
   151  		stmt := currentTable.Diff(table)
   152  		if stmt != nil {
   153  			fmt.Println(stmt.Query)
   154  			continue
   155  		}
   156  	}
   157  
   158  	if err := database.SyncEnum(db); err != nil {
   159  		return err
   160  	}
   161  
   162  	return nil
   163  }