github.com/RevenueMonster/sqlike@v1.0.6/sql/dialect/mysql/table.go (about)

     1  package mysql
     2  
import (
	"errors"
	"reflect"
	"strings"
	"unicode/utf8"

	"github.com/RevenueMonster/sqlike/reflext"
	"github.com/RevenueMonster/sqlike/sql/driver"
	sqlstmt "github.com/RevenueMonster/sqlike/sql/stmt"
	"github.com/RevenueMonster/sqlike/sql/util"
	"github.com/RevenueMonster/sqlike/sqlike/columns"
	"github.com/RevenueMonster/sqlike/sqlike/indexes"
)
    14  
    15  // HasPrimaryKey :
    16  func (ms MySQL) HasPrimaryKey(stmt sqlstmt.Stmt, db, table string) {
    17  	stmt.WriteString("SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS ")
    18  	stmt.WriteString("WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ? AND CONSTRAINT_TYPE = 'PRIMARY KEY'")
    19  	stmt.WriteByte(';')
    20  	stmt.AppendArgs(db, table)
    21  }
    22  
    23  // RenameTable :
    24  func (ms MySQL) RenameTable(stmt sqlstmt.Stmt, db, oldName, newName string) {
    25  	stmt.WriteString("RENAME TABLE ")
    26  	stmt.WriteString(ms.TableName(db, oldName))
    27  	stmt.WriteString(" TO ")
    28  	stmt.WriteString(ms.TableName(db, newName))
    29  	stmt.WriteByte(';')
    30  }
    31  
    32  // DropTable :
    33  func (ms MySQL) DropTable(stmt sqlstmt.Stmt, db, table string, exists bool) {
    34  	stmt.WriteString("DROP TABLE")
    35  	if exists {
    36  		stmt.WriteString(" IF EXISTS")
    37  	}
    38  	stmt.WriteByte(' ')
    39  	stmt.WriteString(ms.TableName(db, table) + ";")
    40  }
    41  
    42  // TruncateTable :
    43  func (ms MySQL) TruncateTable(stmt sqlstmt.Stmt, db, table string) {
    44  	stmt.WriteString("TRUNCATE TABLE " + ms.TableName(db, table) + ";")
    45  }
    46  
    47  // HasTable :
    48  func (ms MySQL) HasTable(stmt sqlstmt.Stmt, dbName, table string) {
    49  	stmt.WriteString(`SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?;`)
    50  	stmt.AppendArgs(dbName, table)
    51  }
    52  
    53  // CreateTable :
    54  func (ms MySQL) CreateTable(stmt sqlstmt.Stmt, db, table, pk string, info driver.Info, fields []reflext.StructFielder) (err error) {
    55  	var (
    56  		col     columns.Column
    57  		pkk     reflext.StructFielder
    58  		k1, k2  string
    59  		virtual bool
    60  		stored  bool
    61  	)
    62  
    63  	stmt.WriteString("CREATE TABLE " + ms.TableName(db, table) + " ")
    64  	stmt.WriteByte('(')
    65  
    66  	// Main columns :
    67  	for i, sf := range fields {
    68  		if i > 0 {
    69  			stmt.WriteByte(',')
    70  		}
    71  
    72  		col, err = ms.schema.GetColumn(info, sf)
    73  		if err != nil {
    74  			return
    75  		}
    76  
    77  		tag := sf.Tag()
    78  		// allow primary_key tag to override
    79  		if _, ok := tag.LookUp("primary_key"); ok {
    80  			pkk = sf
    81  		} else if _, ok := tag.LookUp("auto_increment"); ok {
    82  			pkk = sf
    83  		} else if sf.Name() == pk && pkk == nil {
    84  			pkk = sf
    85  		}
    86  
    87  		idx := indexes.Index{Columns: indexes.Columns(sf.Name())}
    88  		if _, ok := tag.LookUp("unique_index"); ok {
    89  			stmt.WriteString("UNIQUE INDEX " + idx.GetName() + " (" + ms.Quote(sf.Name()) + ")")
    90  			stmt.WriteByte(',')
    91  		}
    92  
    93  		ms.buildSchemaByColumn(stmt, col)
    94  
    95  		if v, ok := tag.LookUp("comment"); ok {
    96  			if len(v) > 60 {
    97  				panic("maximum length of comment is 60 characters")
    98  			}
    99  			stmt.WriteString(" COMMENT '" + v + "'")
   100  		}
   101  
   102  		// check generated columns
   103  		t := reflext.Deref(sf.Type())
   104  		if t.Kind() != reflect.Struct {
   105  			continue
   106  		}
   107  
   108  		children := sf.Children()
   109  		for len(children) > 0 {
   110  			child := children[0]
   111  			tg := child.Tag()
   112  			k1, virtual = tg.LookUp("virtual_column")
   113  			k2, stored = tg.LookUp("stored_column")
   114  			if virtual || stored {
   115  				stmt.WriteByte(',')
   116  				col, err = ms.schema.GetColumn(info, child)
   117  				if err != nil {
   118  					return
   119  				}
   120  
   121  				name := col.Name
   122  				if virtual && k1 != "" {
   123  					name = k1
   124  				}
   125  				if stored && k2 != "" {
   126  					name = k2
   127  				}
   128  
   129  				stmt.WriteString(ms.Quote(name))
   130  				stmt.WriteString(" " + col.Type)
   131  				path := strings.TrimLeft(strings.TrimPrefix(child.Name(), sf.Name()), ".")
   132  				stmt.WriteString(" AS ")
   133  				stmt.WriteString("(" + ms.Quote(sf.Name()) + "->>'$." + path + "')")
   134  				if stored {
   135  					stmt.WriteString(" STORED")
   136  				}
   137  				if !col.Nullable {
   138  					stmt.WriteString(" NOT NULL")
   139  				}
   140  			}
   141  			children = children[1:]
   142  			children = append(children, child.Children()...)
   143  		}
   144  
   145  	}
   146  	if pkk != nil {
   147  		stmt.WriteByte(',')
   148  		stmt.WriteString("PRIMARY KEY (" + ms.Quote(pkk.Name()) + ")")
   149  	}
   150  	stmt.WriteByte(')')
   151  	stmt.WriteString(" ENGINE=INNODB")
   152  	code := string(info.Charset())
   153  	if code == "" {
   154  		stmt.WriteString(" CHARACTER SET utf8mb4")
   155  		stmt.WriteString(" COLLATE utf8mb4_unicode_ci")
   156  	} else {
   157  		stmt.WriteString(" CHARACTER SET " + code)
   158  		if info.Collate() != "" {
   159  			stmt.WriteString(" COLLATE " + info.Collate())
   160  		}
   161  	}
   162  	stmt.WriteByte(';')
   163  	return
   164  }
   165  
   166  // AlterTable :
   167  func (ms *MySQL) AlterTable(stmt sqlstmt.Stmt, db, table, pk string, hasPk bool, info driver.Info, fields []reflext.StructFielder, cols util.StringSlice, idxs util.StringSlice, unsafe bool) (err error) {
   168  	var (
   169  		col     columns.Column
   170  		pkk     reflext.StructFielder
   171  		idx     int
   172  		k1, k2  string
   173  		virtual bool
   174  		stored  bool
   175  	)
   176  
   177  	suffix := "FIRST"
   178  	stmt.WriteString("ALTER TABLE " + ms.TableName(db, table) + " ")
   179  
   180  	for i, sf := range fields {
   181  		if i > 0 {
   182  			stmt.WriteByte(',')
   183  		}
   184  
   185  		action := "ADD"
   186  		idx = cols.IndexOf(sf.Name())
   187  		if idx > -1 {
   188  			action = "MODIFY"
   189  			cols.Splice(idx)
   190  		}
   191  		if !hasPk {
   192  			// allow primary_key tag to override
   193  			if _, ok := sf.Tag().LookUp("primary_key"); ok {
   194  				pkk = sf
   195  			}
   196  			if sf.Name() == pk && pkk == nil {
   197  				pkk = sf
   198  			}
   199  		}
   200  
   201  		tag := sf.Tag()
   202  		_, ok1 := tag.LookUp("unique_index")
   203  		_, ok2 := tag.LookUp("auto_increment")
   204  		if ok1 || ok2 {
   205  			idx := indexes.Index{Columns: indexes.Columns(sf.Name())}
   206  			if idxs.IndexOf(idx.GetName()) < 0 {
   207  				stmt.WriteString("ADD")
   208  				stmt.WriteString(" UNIQUE INDEX " + idx.GetName() + " (" + ms.Quote(sf.Name()) + ")")
   209  				stmt.WriteByte(',')
   210  			}
   211  		}
   212  		stmt.WriteString(action + " ")
   213  		col, err = ms.schema.GetColumn(info, sf)
   214  		if err != nil {
   215  			return
   216  		}
   217  		ms.buildSchemaByColumn(stmt, col)
   218  
   219  		if v, ok := sf.Tag().LookUp("comment"); ok {
   220  			if len(v) > 60 {
   221  				panic("maximum length of comment is 60 characters")
   222  			}
   223  			stmt.WriteString(" COMMENT '" + v + "'")
   224  		}
   225  
   226  		stmt.WriteString(" " + suffix)
   227  		suffix = "AFTER " + ms.Quote(sf.Name())
   228  
   229  		// check generated columns
   230  		t := reflext.Deref(sf.Type())
   231  		if t.Kind() != reflect.Struct {
   232  			continue
   233  		}
   234  
   235  		children := sf.Children()
   236  		for len(children) > 0 {
   237  			child := children[0]
   238  			tg := child.Tag()
   239  			k1, virtual = tg.LookUp("virtual_column")
   240  			k2, stored = tg.LookUp("stored_column")
   241  			if virtual || stored {
   242  				stmt.WriteByte(',')
   243  				col, err = ms.schema.GetColumn(info, child)
   244  				if err != nil {
   245  					return
   246  				}
   247  
   248  				name := col.Name
   249  				if virtual && k1 != "" {
   250  					name = k1
   251  				}
   252  				if stored && k2 != "" {
   253  					name = k2
   254  				}
   255  
   256  				action = "ADD"
   257  				idx = cols.IndexOf(name)
   258  				if idx > -1 {
   259  					action = "MODIFY"
   260  					cols.Splice(idx)
   261  				}
   262  
   263  				stmt.WriteString(action + " ")
   264  				stmt.WriteString(ms.Quote(name))
   265  				stmt.WriteString(" " + col.Type)
   266  				path := strings.TrimLeft(strings.TrimPrefix(child.Name(), sf.Name()), ".")
   267  				stmt.WriteString(" AS ")
   268  				stmt.WriteString("(" + ms.Quote(sf.Name()) + "->>'$." + path + "')")
   269  				if stored {
   270  					stmt.WriteString(" STORED")
   271  				}
   272  				if !col.Nullable {
   273  					stmt.WriteString(" NOT NULL")
   274  				}
   275  				stmt.WriteString(" " + suffix)
   276  				suffix = "AFTER " + ms.Quote(name)
   277  			}
   278  			children = children[1:]
   279  			children = append(children, child.Children()...)
   280  		}
   281  
   282  	}
   283  
   284  	if pkk != nil {
   285  		stmt.WriteByte(',')
   286  		stmt.WriteString("ADD PRIMARY KEY (" + ms.Quote(pkk.Name()) + ")")
   287  	}
   288  
   289  	if unsafe {
   290  		for _, col := range cols {
   291  			stmt.WriteByte(',')
   292  			stmt.WriteString("DROP COLUMN ")
   293  			stmt.WriteString(ms.Quote(col))
   294  		}
   295  	}
   296  
   297  	// TODO: character set
   298  	stmt.WriteByte(',')
   299  	stmt.WriteString(`CHARACTER SET utf8mb4`)
   300  	stmt.WriteString(` COLLATE utf8mb4_unicode_ci`)
   301  	stmt.WriteByte(';')
   302  	return
   303  }