github.com/artisanhe/tools@v1.0.1-0.20210607022958-19a8fef2eb04/sqlx/builder/obj_table.go (about)

     1  package builder
     2  
     3  import (
     4  	"fmt"
     5  	"regexp"
     6  	"sort"
     7  	"strings"
     8  
     9  	"github.com/artisanhe/tools/env"
    10  )
    11  
    12  type TableDef interface {
    13  	IsValidDef() bool
    14  	Def() *Expression
    15  }
    16  
    17  func T(db *Database, tableName string) *Table {
    18  	return &Table{
    19  		DB:   db,
    20  		Name: tableName,
    21  	}
    22  }
    23  
    24  type Table struct {
    25  	DB   *Database
    26  	Name string
    27  	Columns
    28  	Keys
    29  	Engine  string
    30  	Charset string
    31  }
    32  
    33  func (t Table) Define(defs ...TableDef) *Table {
    34  	for _, def := range defs {
    35  		if def.IsValidDef() {
    36  			switch def.(type) {
    37  			case *Column:
    38  				t.Columns.Add(def.(*Column))
    39  			case *Key:
    40  				t.Keys.Add(def.(*Key))
    41  			}
    42  		}
    43  	}
    44  	return &t
    45  }
    46  
    47  var (
    48  	fieldNamePlaceholder = regexp.MustCompile("#[A-Z][A-Za-z0-9_]+")
    49  )
    50  
    51  func (t *Table) Ex(query string, args ...interface{}) *Expression {
    52  	finalQuery := fieldNamePlaceholder.ReplaceAllStringFunc(query, func(i string) string {
    53  		fieldName := strings.TrimLeft(i, "#")
    54  		if col := t.F(fieldName); col != nil {
    55  			return col.String()
    56  		}
    57  		return i
    58  	})
    59  	return Expr(finalQuery, args...)
    60  }
    61  
    62  func (t *Table) Cond(query string, args ...interface{}) *Condition {
    63  	return (*Condition)(t.Ex(query, args...))
    64  }
    65  
    66  type FieldValues map[string]interface{}
    67  
    68  func (t *Table) ColumnsAndValuesByFieldValues(fieldValues FieldValues) (columns Columns, args []interface{}) {
    69  	fieldNames := make([]string, 0)
    70  	for fieldName := range fieldValues {
    71  		fieldNames = append(fieldNames, fieldName)
    72  	}
    73  
    74  	sort.Strings(fieldNames)
    75  
    76  	for _, fieldName := range fieldNames {
    77  		if col := t.F(fieldName); col != nil {
    78  			columns.Add(col)
    79  			args = append(args, fieldValues[fieldName])
    80  		}
    81  	}
    82  	return
    83  }
    84  
    85  func (t *Table) AssignsByFieldValues(fieldValues FieldValues) (assignments Assignments) {
    86  	for fieldName, value := range fieldValues {
    87  		col := t.F(fieldName)
    88  		if col != nil {
    89  			assignments = append(assignments, col.By(value))
    90  		}
    91  	}
    92  	return
    93  }
    94  
    95  func (t *Table) String() string {
    96  	return quote(t.Name)
    97  }
    98  
    99  func (t *Table) FullName() string {
   100  	return t.DB.String() + "." + t.String()
   101  }
   102  
   103  func (t *Table) Insert() *StmtInsert {
   104  	return Insert(t)
   105  }
   106  
   107  func (t *Table) Delete() *StmtDelete {
   108  	return Delete(t)
   109  }
   110  
   111  func (t *Table) Select() *StmtSelect {
   112  	return SelectFrom(t)
   113  }
   114  
   115  func (t *Table) Update() *StmtUpdate {
   116  	return Update(t)
   117  }
   118  
   119  func (t *Table) Drop() *Stmt {
   120  	return (*Stmt)(Expr(fmt.Sprintf("DROP TABLE %s", t.FullName())))
   121  }
   122  
   123  func (t *Table) Truncate() *Stmt {
   124  	return (*Stmt)(Expr(fmt.Sprintf("TRUNCATE TABLE %s", t.FullName())))
   125  }
   126  
   127  func (t *Table) Diff(table *Table) *Stmt {
   128  	colsDiffResult := t.Columns.Diff(table.Columns)
   129  	keysDiffResult := t.Keys.Diff(table.Keys)
   130  
   131  	colsChanged := colsDiffResult.IsChanged()
   132  	indexesChanged := keysDiffResult.IsChanged()
   133  
   134  	if !colsChanged && !indexesChanged {
   135  		return nil
   136  	}
   137  	expr := Expr(fmt.Sprintf(`ALTER TABLE %s `, t.FullName()))
   138  
   139  	joiner := ""
   140  
   141  	if colsChanged {
   142  		if Configuration.DropColumnWhenMigration || env.IsOnline() {
   143  			colsDiffResult.colsForDelete.Range(func(col *Column, idx int) {
   144  				expr = expr.ConcatBy(joiner, col.Drop())
   145  				joiner = ", "
   146  			})
   147  		}
   148  		colsDiffResult.colsForUpdate.Range(func(col *Column, idx int) {
   149  			expr = expr.ConcatBy(joiner, col.Modify())
   150  			joiner = ", "
   151  		})
   152  		colsDiffResult.colsForAdd.Range(func(col *Column, idx int) {
   153  			expr = expr.ConcatBy(joiner, col.Add())
   154  			joiner = ", "
   155  		})
   156  	}
   157  
   158  	if indexesChanged {
   159  		keysDiffResult.keysForDelete.Range(func(key *Key, idx int) {
   160  			expr = expr.ConcatBy(joiner, key.Drop())
   161  			joiner = ", "
   162  		})
   163  		keysDiffResult.keysForUpdate.Range(func(key *Key, idx int) {
   164  			expr = expr.ConcatBy(joiner, key.Drop())
   165  			joiner = ", "
   166  			expr = expr.ConcatBy(joiner, key.Add())
   167  		})
   168  		keysDiffResult.keysForAdd.Range(func(key *Key, idx int) {
   169  			expr = expr.ConcatBy(joiner, key.Add())
   170  			joiner = ", "
   171  		})
   172  	}
   173  
   174  	return (*Stmt)(expr)
   175  }
   176  
   177  func (t *Table) Create(ifNotExists bool) *Stmt {
   178  	expr := Expr("CREATE TABLE")
   179  	if ifNotExists {
   180  		expr = expr.ConcatBy(" ", Expr("IF NOT EXISTS"))
   181  	}
   182  	expr.Query = expr.Query + fmt.Sprintf(" %s (", t.FullName())
   183  
   184  	if !t.Columns.IsEmpty() {
   185  		isFirstCol := true
   186  
   187  		t.Columns.Range(func(col *Column, idx int) {
   188  			joiner := ", "
   189  			if isFirstCol {
   190  				joiner = ""
   191  			}
   192  			def := col.Def()
   193  			if def != nil {
   194  				isFirstCol = false
   195  				expr = expr.ConcatBy(joiner, col.Def())
   196  			}
   197  		})
   198  
   199  		t.Keys.Range(func(key *Key, idx int) {
   200  			expr = expr.ConcatBy(", ", key.Def())
   201  		})
   202  	}
   203  
   204  	engine := t.Engine
   205  	if engine == "" {
   206  		engine = "InnoDB"
   207  	}
   208  
   209  	charset := t.Charset
   210  	if charset == "" {
   211  		charset = "utf8"
   212  	}
   213  
   214  	expr.Query = fmt.Sprintf("%s) ENGINE=%s CHARSET=%s", expr.Query, engine, charset)
   215  	return (*Stmt)(expr)
   216  }
   217  
   218  type Tables map[string]*Table
   219  
   220  func (tables Tables) TableNames() (names []string) {
   221  	for name := range tables {
   222  		names = append(names, name)
   223  	}
   224  	return
   225  }
   226  
   227  func (tables Tables) Add(table *Table) {
   228  	tables[table.Name] = table
   229  }