github.com/kubecost/golang-migrate-duckdb/v4@v4.17.0-duckdb.1/database/ql/ql.go (about)

     1  package ql
     2  
     3  import (
     4  	"database/sql"
     5  	"fmt"
     6  	"io"
     7  	nurl "net/url"
     8  	"strings"
     9  
    10  	"github.com/hashicorp/go-multierror"
    11  	"go.uber.org/atomic"
    12  
    13  	"github.com/golang-migrate/migrate/v4"
    14  	"github.com/golang-migrate/migrate/v4/database"
    15  	_ "modernc.org/ql/driver"
    16  )
    17  
    18  func init() {
    19  	database.Register("ql", &Ql{})
    20  }
    21  
    22  var DefaultMigrationsTable = "schema_migrations"
    23  var (
    24  	ErrDatabaseDirty  = fmt.Errorf("database is dirty")
    25  	ErrNilConfig      = fmt.Errorf("no config")
    26  	ErrNoDatabaseName = fmt.Errorf("no database name")
    27  	ErrAppendPEM      = fmt.Errorf("failed to append PEM")
    28  )
    29  
    30  type Config struct {
    31  	MigrationsTable string
    32  	DatabaseName    string
    33  }
    34  
    35  type Ql struct {
    36  	db       *sql.DB
    37  	isLocked atomic.Bool
    38  
    39  	config *Config
    40  }
    41  
    42  func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
    43  	if config == nil {
    44  		return nil, ErrNilConfig
    45  	}
    46  
    47  	if err := instance.Ping(); err != nil {
    48  		return nil, err
    49  	}
    50  
    51  	if len(config.MigrationsTable) == 0 {
    52  		config.MigrationsTable = DefaultMigrationsTable
    53  	}
    54  
    55  	mx := &Ql{
    56  		db:     instance,
    57  		config: config,
    58  	}
    59  	if err := mx.ensureVersionTable(); err != nil {
    60  		return nil, err
    61  	}
    62  	return mx, nil
    63  }
    64  
    65  // ensureVersionTable checks if versions table exists and, if not, creates it.
    66  // Note that this function locks the database, which deviates from the usual
    67  // convention of "caller locks" in the Ql type.
    68  func (m *Ql) ensureVersionTable() (err error) {
    69  	if err = m.Lock(); err != nil {
    70  		return err
    71  	}
    72  
    73  	defer func() {
    74  		if e := m.Unlock(); e != nil {
    75  			if err == nil {
    76  				err = e
    77  			} else {
    78  				err = multierror.Append(err, e)
    79  			}
    80  		}
    81  	}()
    82  
    83  	tx, err := m.db.Begin()
    84  	if err != nil {
    85  		return err
    86  	}
    87  	if _, err := tx.Exec(fmt.Sprintf(`
    88  	CREATE TABLE IF NOT EXISTS %s (version uint64, dirty bool);
    89  	CREATE UNIQUE INDEX IF NOT EXISTS version_unique ON %s (version);
    90  `, m.config.MigrationsTable, m.config.MigrationsTable)); err != nil {
    91  		if err := tx.Rollback(); err != nil {
    92  			return err
    93  		}
    94  		return err
    95  	}
    96  	if err := tx.Commit(); err != nil {
    97  		return err
    98  	}
    99  	return nil
   100  }
   101  
   102  func (m *Ql) Open(url string) (database.Driver, error) {
   103  	purl, err := nurl.Parse(url)
   104  	if err != nil {
   105  		return nil, err
   106  	}
   107  	dbfile := strings.Replace(migrate.FilterCustomQuery(purl).String(), "ql://", "", 1)
   108  	db, err := sql.Open("ql", dbfile)
   109  	if err != nil {
   110  		return nil, err
   111  	}
   112  	migrationsTable := purl.Query().Get("x-migrations-table")
   113  	if len(migrationsTable) == 0 {
   114  		migrationsTable = DefaultMigrationsTable
   115  	}
   116  	mx, err := WithInstance(db, &Config{
   117  		DatabaseName:    purl.Path,
   118  		MigrationsTable: migrationsTable,
   119  	})
   120  	if err != nil {
   121  		return nil, err
   122  	}
   123  	return mx, nil
   124  }
   125  func (m *Ql) Close() error {
   126  	return m.db.Close()
   127  }
   128  func (m *Ql) Drop() (err error) {
   129  	query := `SELECT Name FROM __Table`
   130  	tables, err := m.db.Query(query)
   131  	if err != nil {
   132  		return &database.Error{OrigErr: err, Query: []byte(query)}
   133  	}
   134  	defer func() {
   135  		if errClose := tables.Close(); errClose != nil {
   136  			err = multierror.Append(err, errClose)
   137  		}
   138  	}()
   139  
   140  	tableNames := make([]string, 0)
   141  	for tables.Next() {
   142  		var tableName string
   143  		if err := tables.Scan(&tableName); err != nil {
   144  			return err
   145  		}
   146  		if len(tableName) > 0 {
   147  			if !strings.HasPrefix(tableName, "__") {
   148  				tableNames = append(tableNames, tableName)
   149  			}
   150  		}
   151  	}
   152  	if err := tables.Err(); err != nil {
   153  		return &database.Error{OrigErr: err, Query: []byte(query)}
   154  	}
   155  
   156  	if len(tableNames) > 0 {
   157  		for _, t := range tableNames {
   158  			query := "DROP TABLE " + t
   159  			err = m.executeQuery(query)
   160  			if err != nil {
   161  				return &database.Error{OrigErr: err, Query: []byte(query)}
   162  			}
   163  		}
   164  	}
   165  
   166  	return nil
   167  }
   168  func (m *Ql) Lock() error {
   169  	if !m.isLocked.CAS(false, true) {
   170  		return database.ErrLocked
   171  	}
   172  	return nil
   173  }
   174  func (m *Ql) Unlock() error {
   175  	if !m.isLocked.CAS(true, false) {
   176  		return database.ErrNotLocked
   177  	}
   178  	return nil
   179  }
   180  func (m *Ql) Run(migration io.Reader) error {
   181  	migr, err := io.ReadAll(migration)
   182  	if err != nil {
   183  		return err
   184  	}
   185  	query := string(migr[:])
   186  
   187  	return m.executeQuery(query)
   188  }
   189  func (m *Ql) executeQuery(query string) error {
   190  	tx, err := m.db.Begin()
   191  	if err != nil {
   192  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   193  	}
   194  	if _, err := tx.Exec(query); err != nil {
   195  		if errRollback := tx.Rollback(); errRollback != nil {
   196  			err = multierror.Append(err, errRollback)
   197  		}
   198  		return &database.Error{OrigErr: err, Query: []byte(query)}
   199  	}
   200  	if err := tx.Commit(); err != nil {
   201  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   202  	}
   203  	return nil
   204  }
   205  func (m *Ql) SetVersion(version int, dirty bool) error {
   206  	tx, err := m.db.Begin()
   207  	if err != nil {
   208  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   209  	}
   210  
   211  	query := "TRUNCATE TABLE " + m.config.MigrationsTable
   212  	if _, err := tx.Exec(query); err != nil {
   213  		return &database.Error{OrigErr: err, Query: []byte(query)}
   214  	}
   215  
   216  	// Also re-write the schema version for nil dirty versions to prevent
   217  	// empty schema version for failed down migration on the first migration
   218  	// See: https://github.com/golang-migrate/migrate/issues/330
   219  	if version >= 0 || (version == database.NilVersion && dirty) {
   220  		query := fmt.Sprintf(`INSERT INTO %s (version, dirty) VALUES (uint64(?1), ?2)`,
   221  			m.config.MigrationsTable)
   222  		if _, err := tx.Exec(query, version, dirty); err != nil {
   223  			if errRollback := tx.Rollback(); errRollback != nil {
   224  				err = multierror.Append(err, errRollback)
   225  			}
   226  			return &database.Error{OrigErr: err, Query: []byte(query)}
   227  		}
   228  	}
   229  
   230  	if err := tx.Commit(); err != nil {
   231  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   232  	}
   233  
   234  	return nil
   235  }
   236  
   237  func (m *Ql) Version() (version int, dirty bool, err error) {
   238  	query := "SELECT version, dirty FROM " + m.config.MigrationsTable + " LIMIT 1"
   239  	err = m.db.QueryRow(query).Scan(&version, &dirty)
   240  	if err != nil {
   241  		return database.NilVersion, false, nil
   242  	}
   243  	return version, dirty, nil
   244  }