github.com/kunlun-qilian/sqlx/v3@v3.0.0/builder/def_table.go (about)

     1  package builder
     2  
     3  import (
     4  	"bytes"
     5  	"container/list"
     6  	"context"
     7  	"fmt"
     8  	"sort"
     9  	"strings"
    10  	"text/scanner"
    11  )
    12  
// TableDefinition is implemented by values that can be passed to T.
// T attaches *Column and *Key definitions to the new table and ignores any
// other implementation.
type TableDefinition interface {
	// T presumably returns the table the definition belongs to; it is not
	// called anywhere in this file — confirm against the implementations.
	T() *Table
}
    16  
    17  func T(tableName string, tableDefinitions ...TableDefinition) *Table {
    18  	t := &Table{
    19  		Name: tableName,
    20  	}
    21  
    22  	for _, tableDef := range tableDefinitions {
    23  		switch d := tableDef.(type) {
    24  		case *Column:
    25  			t.AddCol(d)
    26  		}
    27  	}
    28  	for _, tableDef := range tableDefinitions {
    29  		switch d := tableDef.(type) {
    30  		case *Key:
    31  			t.AddKey(d)
    32  		}
    33  	}
    34  	return t
    35  }
    36  
// Table describes a database table: its name, optional schema, the model it
// is associated with, and its columns and keys.
type Table struct {
	// Name is the unqualified table name.
	Name        string
	// Description holds descriptive text lines; not read in this file.
	Description []string

	// Schema, when non-empty, qualifies the name as "Schema.Name" in Ex.
	Schema    string
	// ModelName, when non-empty, is the key under which Tables.Add indexes
	// the table for Tables.Model lookups.
	ModelName string
	// Model is the model associated with the table; not set in this file.
	Model     Model

	// Columns and Keys are embedded; lookups such as F, Col and Key used on
	// *Table in this file are presumably promoted from them.
	Columns
	Keys
}
    48  
    49  func (t *Table) TableName() string {
    50  	return t.Name
    51  }
    52  
    53  func (t *Table) IsNil() bool {
    54  	return t == nil || len(t.Name) == 0
    55  }
    56  
    57  func (t Table) WithSchema(schema string) *Table {
    58  	t.Schema = schema
    59  
    60  	cols := Columns{}
    61  	t.Columns.Range(func(col *Column, idx int) {
    62  		cols.Add(col.On(&t))
    63  	})
    64  	t.Columns = cols
    65  
    66  	keys := Keys{}
    67  	t.Keys.Range(func(key *Key, idx int) {
    68  		keys.Add(key.On(&t))
    69  	})
    70  	t.Keys = keys
    71  
    72  	return &t
    73  }
    74  
    75  func (t *Table) Ex(ctx context.Context) *Ex {
    76  	if t.Schema != "" {
    77  		return Expr(t.Schema + "." + t.Name).Ex(ctx)
    78  	}
    79  	return Expr(t.Name).Ex(ctx)
    80  }
    81  
    82  func (t *Table) AddCol(d *Column) {
    83  	if d == nil {
    84  		return
    85  	}
    86  	t.Columns.Add(d.On(t))
    87  }
    88  
    89  func (t *Table) AddKey(key *Key) {
    90  	if key == nil {
    91  		return
    92  	}
    93  	t.Keys.Add(key.On(t))
    94  }
    95  
    96  func (t *Table) Expr(query string, args ...interface{}) *Ex {
    97  	if query == "" {
    98  		return nil
    99  	}
   100  
   101  	n := len(args)
   102  	e := Expr("")
   103  	e.Grow(n)
   104  
   105  	s := &scanner.Scanner{}
   106  	s.Init(bytes.NewBuffer([]byte(query)))
   107  
   108  	queryCount := 0
   109  
   110  	for tok := s.Next(); tok != scanner.EOF; tok = s.Next() {
   111  		switch tok {
   112  		case '#':
   113  			fieldNameBuf := bytes.NewBuffer(nil)
   114  
   115  			e.WriteHolder(0)
   116  
   117  			for {
   118  				tok = s.Next()
   119  
   120  				if tok == scanner.EOF {
   121  					break
   122  				}
   123  
   124  				if (tok >= 'A' && tok <= 'Z') ||
   125  					(tok >= 'a' && tok <= 'z') ||
   126  					(tok >= '0' && tok <= '9') ||
   127  					tok == '_' {
   128  
   129  					fieldNameBuf.WriteRune(tok)
   130  					continue
   131  				}
   132  
   133  				e.WriteQueryByte(byte(tok))
   134  
   135  				break
   136  			}
   137  
   138  			if fieldNameBuf.Len() == 0 {
   139  				e.AppendArgs(t)
   140  			} else {
   141  				fieldName := fieldNameBuf.String()
   142  				col := t.F(fieldNameBuf.String())
   143  				if col == nil {
   144  					panic(fmt.Errorf("missing field fieldName %s of table %s", fieldName, t.Name))
   145  				}
   146  				e.AppendArgs(col)
   147  			}
   148  		case '?':
   149  			e.WriteQueryByte(byte(tok))
   150  			if queryCount < n {
   151  				e.AppendArgs(args[queryCount])
   152  				queryCount++
   153  			}
   154  		default:
   155  			e.WriteQueryByte(byte(tok))
   156  		}
   157  	}
   158  
   159  	return e
   160  }
   161  
   162  func (t *Table) ColumnsAndValuesByFieldValues(fieldValues FieldValues) (columns *Columns, args []interface{}) {
   163  	fieldNames := make([]string, 0)
   164  	for fieldName := range fieldValues {
   165  		fieldNames = append(fieldNames, fieldName)
   166  	}
   167  
   168  	sort.Strings(fieldNames)
   169  
   170  	columns = &Columns{}
   171  
   172  	for _, fieldName := range fieldNames {
   173  		if col := t.F(fieldName); col != nil {
   174  			columns.Add(col)
   175  			args = append(args, fieldValues[fieldName])
   176  		}
   177  	}
   178  	return
   179  }
   180  
   181  func (t *Table) AssignmentsByFieldValues(fieldValues FieldValues) (assignments Assignments) {
   182  	for fieldName, value := range fieldValues {
   183  		col := t.F(fieldName)
   184  		if col != nil {
   185  			assignments = append(assignments, col.ValueBy(value))
   186  		}
   187  	}
   188  	return
   189  }
   190  
   191  func (t *Table) Diff(prevTable *Table, dialect Dialect) (exprList []SqlExpr) {
   192  	// diff columns
   193  	t.Columns.Range(func(currentCol *Column, idx int) {
   194  		if prevCol := prevTable.Col(currentCol.Name); prevCol != nil {
   195  			if currentCol != nil {
   196  				if currentCol.DeprecatedActions != nil {
   197  					renameTo := currentCol.DeprecatedActions.RenameTo
   198  					if renameTo != "" {
   199  						prevCol := prevTable.Col(renameTo)
   200  						if prevCol != nil {
   201  							exprList = append(exprList, dialect.DropColumn(prevCol))
   202  						}
   203  						targetCol := t.Col(renameTo)
   204  						if targetCol == nil {
   205  							panic(fmt.Errorf("col `%s` is not declared", renameTo))
   206  						}
   207  						exprList = append(exprList, dialect.RenameColumn(currentCol, targetCol))
   208  						prevTable.AddCol(targetCol)
   209  						return
   210  					}
   211  					exprList = append(exprList, dialect.DropColumn(currentCol))
   212  					return
   213  				}
   214  
   215  				prevColType := dialect.DataType(prevCol.ColumnType).Ex(context.Background()).Query()
   216  				currentColType := dialect.DataType(currentCol.ColumnType).Ex(context.Background()).Query()
   217  
   218  				if currentColType != prevColType {
   219  					exprList = append(exprList, dialect.ModifyColumn(currentCol, prevCol))
   220  				}
   221  				return
   222  			}
   223  			exprList = append(exprList, dialect.DropColumn(currentCol))
   224  			return
   225  		}
   226  
   227  		if currentCol.DeprecatedActions == nil {
   228  			exprList = append(exprList, dialect.AddColumn(currentCol))
   229  		}
   230  	})
   231  
   232  	// indexes
   233  	indexes := map[string]bool{}
   234  
   235  	t.Keys.Range(func(key *Key, idx int) {
   236  		if key.IsPartition() {
   237  			return
   238  		}
   239  
   240  		name := key.Name
   241  		if key.IsPrimary() {
   242  			name = dialect.PrimaryKeyName()
   243  		}
   244  		indexes[name] = true
   245  
   246  		prevKey := prevTable.Key(name)
   247  		if prevKey == nil {
   248  			exprList = append(exprList, dialect.AddIndex(key))
   249  		} else {
   250  			if !key.IsPrimary() {
   251  				indexDef := key.Def.TableExpr(key.Table).Ex(context.Background()).Query()
   252  				prevIndexDef := prevKey.Def.TableExpr(prevKey.Table).Ex(context.Background()).Query()
   253  
   254  				if !strings.EqualFold(indexDef, prevIndexDef) {
   255  					exprList = append(exprList, dialect.DropIndex(key))
   256  					exprList = append(exprList, dialect.AddIndex(key))
   257  				}
   258  			}
   259  		}
   260  	})
   261  
   262  	prevTable.Keys.Range(func(key *Key, idx int) {
   263  		if _, ok := indexes[strings.ToLower(key.Name)]; !ok {
   264  			exprList = append(exprList, dialect.DropIndex(key))
   265  		}
   266  	})
   267  
   268  	return
   269  }
   270  
// Tables is an insertion-ordered collection of *Table values. It is indexed
// by Table.Name and, when set, by Table.ModelName. The zero value is usable:
// Add allocates the internal structures on first use, and the read methods
// tolerate them being nil.
type Tables struct {
	// l holds the tables in insertion order; each element's Value is a *Table.
	l      *list.List
	// tables maps Table.Name to its element in l.
	tables map[string]*list.Element
	// models maps Table.ModelName to its element in l.
	models map[string]*list.Element
}
   276  
   277  func (tables *Tables) TableNames() (names []string) {
   278  	tables.Range(func(tab *Table, idx int) {
   279  		names = append(names, tab.Name)
   280  	})
   281  	return
   282  }
   283  
   284  func (tables *Tables) Add(tabs ...*Table) {
   285  	if tables.tables == nil {
   286  		tables.tables = map[string]*list.Element{}
   287  		tables.models = map[string]*list.Element{}
   288  		tables.l = list.New()
   289  	}
   290  
   291  	for _, tab := range tabs {
   292  		if tab != nil {
   293  			if _, ok := tables.tables[tab.Name]; ok {
   294  				tables.Remove(tab.Name)
   295  			}
   296  
   297  			e := tables.l.PushBack(tab)
   298  			tables.tables[tab.Name] = e
   299  			if tab.ModelName != "" {
   300  				tables.models[tab.ModelName] = e
   301  			}
   302  		}
   303  	}
   304  }
   305  
   306  func (tables *Tables) Table(tableName string) *Table {
   307  	if tables.tables != nil {
   308  		if c, ok := tables.tables[tableName]; ok {
   309  			return c.Value.(*Table)
   310  		}
   311  	}
   312  	return nil
   313  }
   314  
   315  func (tables *Tables) Model(structName string) *Table {
   316  	if tables.models != nil {
   317  		if c, ok := tables.models[structName]; ok {
   318  			return c.Value.(*Table)
   319  		}
   320  	}
   321  	return nil
   322  }
   323  
   324  func (tables *Tables) Remove(name string) {
   325  	if tables.tables != nil {
   326  		if e, exists := tables.tables[name]; exists {
   327  			tables.l.Remove(e)
   328  			delete(tables.tables, name)
   329  		}
   330  	}
   331  }
   332  
   333  func (tables *Tables) Range(cb func(tab *Table, idx int)) {
   334  	if tables.l != nil {
   335  		i := 0
   336  		for e := tables.l.Front(); e != nil; e = e.Next() {
   337  			cb(e.Value.(*Table), i)
   338  			i++
   339  		}
   340  	}
   341  }