github.com/amacneil/dbmate@v1.16.3-0.20230225174651-ca89b10d75d7/pkg/driver/sqlite/sqlite.go (about)

     1  //go:build cgo
     2  // +build cgo
     3  
     4  package sqlite
     5  
     6  import (
     7  	"bytes"
     8  	"database/sql"
     9  	"fmt"
    10  	"io"
    11  	"net/url"
    12  	"os"
    13  	"regexp"
    14  	"strings"
    15  
    16  	"github.com/amacneil/dbmate/pkg/dbmate"
    17  	"github.com/amacneil/dbmate/pkg/dbutil"
    18  
    19  	"github.com/lib/pq"
    20  	_ "github.com/mattn/go-sqlite3" // database/sql driver
    21  )
    22  
    23  func init() {
    24  	dbmate.RegisterDriver(NewDriver, "sqlite")
    25  	dbmate.RegisterDriver(NewDriver, "sqlite3")
    26  }
    27  
    28  // Driver provides top level database functions
    29  type Driver struct {
    30  	migrationsTableName string
    31  	databaseURL         *url.URL
    32  	log                 io.Writer
    33  }
    34  
    35  // NewDriver initializes the driver
    36  func NewDriver(config dbmate.DriverConfig) dbmate.Driver {
    37  	return &Driver{
    38  		migrationsTableName: config.MigrationsTableName,
    39  		databaseURL:         config.DatabaseURL,
    40  		log:                 config.Log,
    41  	}
    42  }
    43  
    44  // ConnectionString converts a URL into a valid connection string
    45  func ConnectionString(u *url.URL) string {
    46  	// duplicate URL and remove scheme
    47  	newURL := *u
    48  	newURL.Scheme = ""
    49  
    50  	if newURL.Opaque == "" && newURL.Path != "" {
    51  		// When the DSN is in the form "scheme:/absolute/path" or
    52  		// "scheme://absolute/path" or "scheme:///absolute/path", url.Parse
    53  		// will consider the file path as :
    54  		// - "absolute" as the hostname
    55  		// - "path" (and the rest until "?") as the URL path.
    56  		// Instead, when the DSN is in the form "scheme:", the (relative) file
    57  		// path is stored in the "Opaque" field.
    58  		// See: https://pkg.go.dev/net/url#URL
    59  		//
    60  		// While Opaque is not escaped, the URL Path is. So, if .Path contains
    61  		// the file path, we need to un-escape it, and rebuild the full path.
    62  
    63  		newURL.Opaque = "//" + newURL.Host + dbutil.MustUnescapePath(newURL.Path)
    64  		newURL.Path = ""
    65  	}
    66  
    67  	// trim duplicate leading slashes
    68  	str := regexp.MustCompile("^//+").ReplaceAllString(newURL.String(), "/")
    69  
    70  	return str
    71  }
    72  
    73  // Open creates a new database connection
    74  func (drv *Driver) Open() (*sql.DB, error) {
    75  	return sql.Open("sqlite3", ConnectionString(drv.databaseURL))
    76  }
    77  
    78  // CreateDatabase creates the specified database
    79  func (drv *Driver) CreateDatabase() error {
    80  	fmt.Fprintf(drv.log, "Creating: %s\n", ConnectionString(drv.databaseURL))
    81  
    82  	db, err := drv.Open()
    83  	if err != nil {
    84  		return err
    85  	}
    86  	defer dbutil.MustClose(db)
    87  
    88  	return db.Ping()
    89  }
    90  
    91  // DropDatabase drops the specified database (if it exists)
    92  func (drv *Driver) DropDatabase() error {
    93  	path := ConnectionString(drv.databaseURL)
    94  	fmt.Fprintf(drv.log, "Dropping: %s\n", path)
    95  
    96  	exists, err := drv.DatabaseExists()
    97  	if err != nil {
    98  		return err
    99  	}
   100  	if !exists {
   101  		return nil
   102  	}
   103  
   104  	return os.Remove(path)
   105  }
   106  
   107  func (drv *Driver) schemaMigrationsDump(db *sql.DB) ([]byte, error) {
   108  	migrationsTable := drv.quotedMigrationsTableName()
   109  
   110  	// load applied migrations
   111  	migrations, err := dbutil.QueryColumn(db,
   112  		fmt.Sprintf("select quote(version) from %s order by version asc", migrationsTable))
   113  	if err != nil {
   114  		return nil, err
   115  	}
   116  
   117  	// build schema migrations table data
   118  	var buf bytes.Buffer
   119  	buf.WriteString("-- Dbmate schema migrations\n")
   120  
   121  	if len(migrations) > 0 {
   122  		buf.WriteString(
   123  			fmt.Sprintf("INSERT INTO %s (version) VALUES\n  (", migrationsTable) +
   124  				strings.Join(migrations, "),\n  (") +
   125  				");\n")
   126  	}
   127  
   128  	return buf.Bytes(), nil
   129  }
   130  
   131  // DumpSchema returns the current database schema
   132  func (drv *Driver) DumpSchema(db *sql.DB) ([]byte, error) {
   133  	path := ConnectionString(drv.databaseURL)
   134  	schema, err := dbutil.RunCommand("sqlite3", path, ".schema --nosys")
   135  	if err != nil {
   136  		return nil, err
   137  	}
   138  
   139  	migrations, err := drv.schemaMigrationsDump(db)
   140  	if err != nil {
   141  		return nil, err
   142  	}
   143  
   144  	schema = append(schema, migrations...)
   145  	return dbutil.TrimLeadingSQLComments(schema)
   146  }
   147  
   148  // DatabaseExists determines whether the database exists
   149  func (drv *Driver) DatabaseExists() (bool, error) {
   150  	_, err := os.Stat(ConnectionString(drv.databaseURL))
   151  	if os.IsNotExist(err) {
   152  		return false, nil
   153  	}
   154  	if err != nil {
   155  		return false, err
   156  	}
   157  
   158  	return true, nil
   159  }
   160  
   161  // MigrationsTableExists checks if the schema_migrations table exists
   162  func (drv *Driver) MigrationsTableExists(db *sql.DB) (bool, error) {
   163  	exists := false
   164  	err := db.QueryRow("SELECT 1 FROM sqlite_master "+
   165  		"WHERE type='table' AND name=$1",
   166  		drv.migrationsTableName).
   167  		Scan(&exists)
   168  	if err == sql.ErrNoRows {
   169  		return false, nil
   170  	}
   171  
   172  	return exists, err
   173  }
   174  
   175  // CreateMigrationsTable creates the schema migrations table
   176  func (drv *Driver) CreateMigrationsTable(db *sql.DB) error {
   177  	_, err := db.Exec(fmt.Sprintf(
   178  		"create table if not exists %s (version varchar(128) primary key)",
   179  		drv.quotedMigrationsTableName()))
   180  
   181  	return err
   182  }
   183  
   184  // SelectMigrations returns a list of applied migrations
   185  // with an optional limit (in descending order)
   186  func (drv *Driver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
   187  	query := fmt.Sprintf("select version from %s order by version desc", drv.quotedMigrationsTableName())
   188  	if limit >= 0 {
   189  		query = fmt.Sprintf("%s limit %d", query, limit)
   190  	}
   191  	rows, err := db.Query(query)
   192  	if err != nil {
   193  		return nil, err
   194  	}
   195  
   196  	defer dbutil.MustClose(rows)
   197  
   198  	migrations := map[string]bool{}
   199  	for rows.Next() {
   200  		var version string
   201  		if err := rows.Scan(&version); err != nil {
   202  			return nil, err
   203  		}
   204  
   205  		migrations[version] = true
   206  	}
   207  
   208  	if err = rows.Err(); err != nil {
   209  		return nil, err
   210  	}
   211  
   212  	return migrations, nil
   213  }
   214  
   215  // InsertMigration adds a new migration record
   216  func (drv *Driver) InsertMigration(db dbutil.Transaction, version string) error {
   217  	_, err := db.Exec(
   218  		fmt.Sprintf("insert into %s (version) values (?)", drv.quotedMigrationsTableName()),
   219  		version)
   220  
   221  	return err
   222  }
   223  
   224  // DeleteMigration removes a migration record
   225  func (drv *Driver) DeleteMigration(db dbutil.Transaction, version string) error {
   226  	_, err := db.Exec(
   227  		fmt.Sprintf("delete from %s where version = ?", drv.quotedMigrationsTableName()),
   228  		version)
   229  
   230  	return err
   231  }
   232  
   233  // Ping verifies a connection to the database. Due to the way SQLite works, by
   234  // testing whether the database is valid, it will automatically create the database
   235  // if it does not already exist.
   236  func (drv *Driver) Ping() error {
   237  	db, err := drv.Open()
   238  	if err != nil {
   239  		return err
   240  	}
   241  	defer dbutil.MustClose(db)
   242  
   243  	return db.Ping()
   244  }
   245  
   246  func (drv *Driver) quotedMigrationsTableName() string {
   247  	return drv.quoteIdentifier(drv.migrationsTableName)
   248  }
   249  
   250  // quoteIdentifier quotes a table or column name
   251  // we fall back to lib/pq implementation since both use ansi standard (double quotes)
   252  // and mattn/go-sqlite3 doesn't provide a sqlite-specific equivalent
   253  func (drv *Driver) quoteIdentifier(s string) string {
   254  	return pq.QuoteIdentifier(s)
   255  }