github.com/marinho/drone@v0.2.1-0.20140504195434-d3ba962e89a7/pkg/database/migrate/mysql.go (about)

     1  package migrate
     2  
import (
	"database/sql"
	"fmt"
	"sort"
	"strings"
)
     8  
     9  type mysqlDriver struct {
    10  	Tx *sql.Tx
    11  }
    12  
    13  func MySQL(tx *sql.Tx) *MigrationDriver {
    14  	return &MigrationDriver{
    15  		Tx:        tx,
    16  		Operation: &mysqlDriver{Tx: tx},
    17  		T: &columnType{
    18  			AttrMap: map[int]string{AUTOINCREMENT: "AUTO_INCREMENT"},
    19  		},
    20  	}
    21  }
    22  
    23  func (m *mysqlDriver) CreateTable(tableName string, args []string) (sql.Result, error) {
    24  	return m.Tx.Exec(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (%s) ROW_FORMAT=DYNAMIC",
    25  		tableName, strings.Join(args, ", ")))
    26  }
    27  
    28  func (m *mysqlDriver) RenameTable(tableName, newName string) (sql.Result, error) {
    29  	return m.Tx.Exec(fmt.Sprintf("ALTER TABLE %s RENAME TO %s", tableName, newName))
    30  }
    31  
    32  func (m *mysqlDriver) DropTable(tableName string) (sql.Result, error) {
    33  	return m.Tx.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s", tableName))
    34  }
    35  
    36  func (m *mysqlDriver) AddColumn(tableName, columnSpec string) (sql.Result, error) {
    37  	return m.Tx.Exec(fmt.Sprintf("ALTER TABLE %s ADD COLUMN (%s)", tableName, columnSpec))
    38  }
    39  
    40  func (m *mysqlDriver) ChangeColumn(tableName, columnName, newSpecs string) (sql.Result, error) {
    41  	return m.Tx.Exec(fmt.Sprintf("ALTER TABLE %s MODIFY %s %s", tableName, columnName, newSpecs))
    42  }
    43  
    44  func (m *mysqlDriver) DropColumns(tableName string, columnsToDrop ...string) (sql.Result, error) {
    45  	if len(columnsToDrop) == 0 {
    46  		return nil, fmt.Errorf("No columns to drop.")
    47  	}
    48  	for k, v := range columnsToDrop {
    49  		columnsToDrop[k] = fmt.Sprintf("DROP %s", v)
    50  	}
    51  	return m.Tx.Exec(fmt.Sprintf("ALTER TABLE %s %s", tableName, strings.Join(columnsToDrop, ", ")))
    52  }
    53  
    54  func (m *mysqlDriver) RenameColumns(tableName string, columnChanges map[string]string) (sql.Result, error) {
    55  	var columns []string
    56  
    57  	tableSQL, err := m.getTableDefinition(tableName)
    58  	if err != nil {
    59  		return nil, err
    60  	}
    61  
    62  	columns, err = fetchColumns(tableSQL)
    63  	if err != nil {
    64  		return nil, err
    65  	}
    66  
    67  	var colspec []string
    68  	for k, v := range columnChanges {
    69  		for _, col := range columns {
    70  			col = strings.Trim(col, " \n")
    71  			cols := strings.SplitN(col, " ", 2)
    72  			if quote(k) == cols[0] {
    73  				colspec = append(colspec, fmt.Sprintf("CHANGE %s %s %s", k, v, cols[1]))
    74  				break
    75  			}
    76  		}
    77  	}
    78  
    79  	return m.Tx.Exec(fmt.Sprintf("ALTER TABLE %s %s", tableName, strings.Join(colspec, ", ")))
    80  }
    81  
    82  func (m *mysqlDriver) AddIndex(tableName string, columns []string, flags ...string) (sql.Result, error) {
    83  	flag := ""
    84  	if len(flags) > 0 {
    85  		switch strings.ToUpper(flags[0]) {
    86  		case "UNIQUE":
    87  			fallthrough
    88  		case "FULLTEXT":
    89  			fallthrough
    90  		case "SPATIAL":
    91  			flag = flags[0]
    92  		}
    93  	}
    94  	return m.Tx.Exec(fmt.Sprintf("CREATE %s INDEX %s ON %s (%s)", flag,
    95  		indexName(tableName, columns), tableName, strings.Join(columns, ", ")))
    96  }
    97  
    98  func (m *mysqlDriver) DropIndex(tableName string, columns []string) (sql.Result, error) {
    99  	return m.Tx.Exec(fmt.Sprintf("DROP INDEX %s on %s", indexName(tableName, columns), tableName))
   100  }
   101  
   102  func (m *mysqlDriver) getTableDefinition(tableName string) (string, error) {
   103  	var name, def string
   104  	st := fmt.Sprintf("SHOW CREATE TABLE %s", tableName)
   105  	if err := m.Tx.QueryRow(st).Scan(&name, &def); err != nil {
   106  		return "", err
   107  	}
   108  	return def, nil
   109  }
   110  
   111  func quote(name string) string {
   112  	return fmt.Sprintf("`%s`", name)
   113  }