github.com/marinho/drone@v0.2.1-0.20140504195434-d3ba962e89a7/pkg/database/migrate/sqlite.go (about)

     1  package migrate
     2  
     3  import (
     4  	"database/sql"
     5  	"fmt"
     6  	"strings"
     7  )
     8  
     9  type sqliteDriver struct {
    10  	Tx *sql.Tx
    11  }
    12  
    13  func SQLite(tx *sql.Tx) *MigrationDriver {
    14  	return &MigrationDriver{
    15  		Tx:        tx,
    16  		Operation: &sqliteDriver{Tx: tx},
    17  		T:         &columnType{},
    18  	}
    19  }
    20  
    21  func (s *sqliteDriver) CreateTable(tableName string, args []string) (sql.Result, error) {
    22  	return s.Tx.Exec(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (%s)", tableName, strings.Join(args, ", ")))
    23  }
    24  
    25  func (s *sqliteDriver) RenameTable(tableName, newName string) (sql.Result, error) {
    26  	return s.Tx.Exec(fmt.Sprintf("ALTER TABLE %s RENAME TO %s", tableName, newName))
    27  }
    28  
    29  func (s *sqliteDriver) DropTable(tableName string) (sql.Result, error) {
    30  	return s.Tx.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s", tableName))
    31  }
    32  
    33  func (s *sqliteDriver) AddColumn(tableName, columnSpec string) (sql.Result, error) {
    34  	return s.Tx.Exec(fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s", tableName, columnSpec))
    35  }
    36  
    37  func (s *sqliteDriver) ChangeColumn(tableName, columnName, newType string) (sql.Result, error) {
    38  	var result sql.Result
    39  	var err error
    40  
    41  	tableSQL, err := s.getTableDefinition(tableName)
    42  	if err != nil {
    43  		return nil, err
    44  	}
    45  
    46  	columns, err := fetchColumns(tableSQL)
    47  	if err != nil {
    48  		return nil, err
    49  	}
    50  
    51  	columnNames := selectName(columns)
    52  
    53  	for k, column := range columnNames {
    54  		if columnName == column {
    55  			columns[k] = fmt.Sprintf("%s %s", columnName, newType)
    56  			break
    57  		}
    58  	}
    59  
    60  	indices, err := s.getIndexDefinition(tableName)
    61  	if err != nil {
    62  		return nil, err
    63  	}
    64  
    65  	proxy := proxyName(tableName)
    66  	if result, err = s.RenameTable(tableName, proxy); err != nil {
    67  		return nil, err
    68  	}
    69  
    70  	if result, err = s.CreateTable(tableName, columns); err != nil {
    71  		return nil, err
    72  	}
    73  
    74  	// Migrate data
    75  	if result, err = s.Tx.Exec(fmt.Sprintf("INSERT INTO %s SELECT %s FROM %s", tableName,
    76  		strings.Join(columnNames, ", "), proxy)); err != nil {
    77  		return result, err
    78  	}
    79  
    80  	// Clean up proxy table
    81  	if result, err = s.DropTable(proxy); err != nil {
    82  		return result, err
    83  	}
    84  
    85  	for _, idx := range indices {
    86  		if result, err = s.Tx.Exec(idx); err != nil {
    87  			return result, err
    88  		}
    89  	}
    90  	return result, err
    91  
    92  }
    93  
    94  func (s *sqliteDriver) DropColumns(tableName string, columnsToDrop ...string) (sql.Result, error) {
    95  	var err error
    96  	var result sql.Result
    97  
    98  	if len(columnsToDrop) == 0 {
    99  		return nil, fmt.Errorf("No columns to drop.")
   100  	}
   101  
   102  	tableSQL, err := s.getTableDefinition(tableName)
   103  	if err != nil {
   104  		return nil, err
   105  	}
   106  
   107  	columns, err := fetchColumns(tableSQL)
   108  	if err != nil {
   109  		return nil, err
   110  	}
   111  
   112  	columnNames := selectName(columns)
   113  
   114  	var preparedColumns []string
   115  	for k, column := range columnNames {
   116  		listed := false
   117  		for _, dropped := range columnsToDrop {
   118  			if column == dropped {
   119  				listed = true
   120  				break
   121  			}
   122  		}
   123  		if !listed {
   124  			preparedColumns = append(preparedColumns, columns[k])
   125  		}
   126  	}
   127  
   128  	if len(preparedColumns) == 0 {
   129  		return nil, fmt.Errorf("No columns match, drops nothing.")
   130  	}
   131  
   132  	// fetch indices for this table
   133  	oldSQLIndices, err := s.getIndexDefinition(tableName)
   134  	if err != nil {
   135  		return nil, err
   136  	}
   137  
   138  	var oldIdxColumns [][]string
   139  	for _, idx := range oldSQLIndices {
   140  		idxCols, err := fetchColumns(idx)
   141  		if err != nil {
   142  			return nil, err
   143  		}
   144  		oldIdxColumns = append(oldIdxColumns, idxCols)
   145  	}
   146  
   147  	var indices []string
   148  	for k, idx := range oldSQLIndices {
   149  		listed := false
   150  	OIdxLoop:
   151  		for _, oidx := range oldIdxColumns[k] {
   152  			for _, cols := range columnsToDrop {
   153  				if oidx == cols {
   154  					listed = true
   155  					break OIdxLoop
   156  				}
   157  			}
   158  		}
   159  		if !listed {
   160  			indices = append(indices, idx)
   161  		}
   162  	}
   163  
   164  	// Rename old table, here's our proxy
   165  	proxy := proxyName(tableName)
   166  	if result, err := s.RenameTable(tableName, proxy); err != nil {
   167  		return result, err
   168  	}
   169  
   170  	// Recreate table with dropped columns omitted
   171  	if result, err = s.CreateTable(tableName, preparedColumns); err != nil {
   172  		return result, err
   173  	}
   174  
   175  	// Move data from old table
   176  	if result, err = s.Tx.Exec(fmt.Sprintf("INSERT INTO %s SELECT %s FROM %s", tableName,
   177  		strings.Join(selectName(preparedColumns), ", "), proxy)); err != nil {
   178  		return result, err
   179  	}
   180  
   181  	// Clean up proxy table
   182  	if result, err = s.DropTable(proxy); err != nil {
   183  		return result, err
   184  	}
   185  
   186  	// Recreate Indices
   187  	for _, idx := range indices {
   188  		if result, err = s.Tx.Exec(idx); err != nil {
   189  			return result, err
   190  		}
   191  	}
   192  	return result, err
   193  }
   194  
   195  func (s *sqliteDriver) RenameColumns(tableName string, columnChanges map[string]string) (sql.Result, error) {
   196  	var err error
   197  	var result sql.Result
   198  
   199  	tableSQL, err := s.getTableDefinition(tableName)
   200  	if err != nil {
   201  		return nil, err
   202  	}
   203  
   204  	columns, err := fetchColumns(tableSQL)
   205  	if err != nil {
   206  		return nil, err
   207  	}
   208  
   209  	// We need a list of columns name to migrate data to the new table
   210  	var oldColumnsName = selectName(columns)
   211  
   212  	// newColumns will be used to create the new table
   213  	var newColumns []string
   214  
   215  	for k, column := range oldColumnsName {
   216  		added := false
   217  		for Old, New := range columnChanges {
   218  			if column == Old {
   219  				columnToAdd := strings.Replace(columns[k], Old, New, 1)
   220  				newColumns = append(newColumns, columnToAdd)
   221  				added = true
   222  				break
   223  			}
   224  		}
   225  		if !added {
   226  			newColumns = append(newColumns, columns[k])
   227  		}
   228  	}
   229  
   230  	// fetch indices for this table
   231  	oldSQLIndices, err := s.getIndexDefinition(tableName)
   232  	if err != nil {
   233  		return nil, err
   234  	}
   235  
   236  	var idxColumns [][]string
   237  	for _, idx := range oldSQLIndices {
   238  		idxCols, err := fetchColumns(idx)
   239  		if err != nil {
   240  			return nil, err
   241  		}
   242  		idxColumns = append(idxColumns, idxCols)
   243  	}
   244  
   245  	var indices []string
   246  	for k, idx := range oldSQLIndices {
   247  		added := false
   248  	IdcLoop:
   249  		for _, oldIdx := range idxColumns[k] {
   250  			for Old, New := range columnChanges {
   251  				if oldIdx == Old {
   252  					indx := strings.Replace(idx, Old, New, 2)
   253  					indices = append(indices, indx)
   254  					added = true
   255  					break IdcLoop
   256  				}
   257  			}
   258  		}
   259  		if !added {
   260  			indices = append(indices, idx)
   261  		}
   262  	}
   263  
   264  	// Rename current table
   265  	proxy := proxyName(tableName)
   266  	if result, err := s.RenameTable(tableName, proxy); err != nil {
   267  		return result, err
   268  	}
   269  
   270  	// Create new table with the new columns
   271  	if result, err = s.CreateTable(tableName, newColumns); err != nil {
   272  		return result, err
   273  	}
   274  
   275  	// Migrate data
   276  	if result, err = s.Tx.Exec(fmt.Sprintf("INSERT INTO %s SELECT %s FROM %s", tableName,
   277  		strings.Join(oldColumnsName, ", "), proxy)); err != nil {
   278  		return result, err
   279  	}
   280  
   281  	// Clean up proxy table
   282  	if result, err = s.DropTable(proxy); err != nil {
   283  		return result, err
   284  	}
   285  
   286  	for _, idx := range indices {
   287  		if result, err = s.Tx.Exec(idx); err != nil {
   288  			return result, err
   289  		}
   290  	}
   291  	return result, err
   292  }
   293  
   294  func (s *sqliteDriver) AddIndex(tableName string, columns []string, flags ...string) (sql.Result, error) {
   295  	flag := ""
   296  	if len(flags) > 0 {
   297  		if strings.ToUpper(flags[0]) == "UNIQUE" {
   298  			flag = flags[0]
   299  		}
   300  	}
   301  	return s.Tx.Exec(fmt.Sprintf("CREATE %s INDEX %s ON %s (%s)", flag, indexName(tableName, columns),
   302  		tableName, strings.Join(columns, ", ")))
   303  }
   304  
   305  func (s *sqliteDriver) DropIndex(tableName string, columns []string) (sql.Result, error) {
   306  	return s.Tx.Exec(fmt.Sprintf("DROP INDEX %s", indexName(tableName, columns)))
   307  }
   308  
   309  func (s *sqliteDriver) getTableDefinition(tableName string) (string, error) {
   310  	var sql string
   311  	query := `SELECT sql FROM sqlite_master WHERE type='table' and name=?`
   312  	err := s.Tx.QueryRow(query, tableName).Scan(&sql)
   313  	if err != nil {
   314  		return "", err
   315  	}
   316  	return sql, nil
   317  }
   318  
   319  func (s *sqliteDriver) getIndexDefinition(tableName string) ([]string, error) {
   320  	var sqls []string
   321  
   322  	query := `SELECT sql FROM sqlite_master WHERE type='index' and tbl_name=?`
   323  	rows, err := s.Tx.Query(query, tableName)
   324  	if err != nil {
   325  		return sqls, err
   326  	}
   327  
   328  	for rows.Next() {
   329  		var sql sql.NullString
   330  		if err := rows.Scan(&sql); err != nil {
   331  			return sqls, err
   332  		}
   333  		if sql.Valid {
   334  			sqls = append(sqls, sql.String)
   335  		}
   336  	}
   337  
   338  	if err := rows.Err(); err != nil {
   339  		return sqls, err
   340  	}
   341  
   342  	return sqls, nil
   343  }