github.com/kubecost/golang-migrate-duckdb/v4@v4.17.0-duckdb.1/database/duckdb/duckdb.go (about)

     1  package duckdb
     2  
import (
	"database/sql"
	"errors"
	"fmt"
	"io"
	nurl "net/url"
	"strings"

	"go.uber.org/atomic"

	"github.com/golang-migrate/migrate/v4"
	"github.com/golang-migrate/migrate/v4/database"
	"github.com/hashicorp/go-multierror"
	_ "github.com/marcboeker/go-duckdb"
)
    17  
    18  func init() {
    19  	database.Register("duckdb", &DuckDB{})
    20  }
    21  
    22  const MigrationTable = "gmg_schema_migrations"
    23  
    24  type DuckDB struct {
    25  	db       *sql.DB
    26  	isLocked atomic.Bool
    27  }
    28  
    29  func (d *DuckDB) Open(url string) (database.Driver, error) {
    30  	purl, err := nurl.Parse(url)
    31  	if err != nil {
    32  		return nil, fmt.Errorf("parsing url: %w", err)
    33  	}
    34  	dbfile := strings.Replace(migrate.FilterCustomQuery(purl).String(), "duckdb://", "", 1)
    35  	db, err := sql.Open("duckdb", dbfile)
    36  	if err != nil {
    37  		return nil, fmt.Errorf("opening '%s': %w", dbfile, err)
    38  	}
    39  
    40  	if err := db.Ping(); err != nil {
    41  		return nil, fmt.Errorf("pinging: %w", err)
    42  	}
    43  	d.db = db
    44  
    45  	if err := d.ensureVersionTable(); err != nil {
    46  		return nil, fmt.Errorf("ensuring version table: %w", err)
    47  	}
    48  
    49  	return d, nil
    50  }
    51  
    52  func (d *DuckDB) Close() error {
    53  	return d.db.Close()
    54  }
    55  
    56  func (d *DuckDB) Lock() error {
    57  	if !d.isLocked.CAS(false, true) {
    58  		return database.ErrLocked
    59  	}
    60  	return nil
    61  }
    62  
    63  func (d *DuckDB) Unlock() error {
    64  	if !d.isLocked.CAS(true, false) {
    65  		return database.ErrNotLocked
    66  	}
    67  	return nil
    68  }
    69  
    70  func (d *DuckDB) Drop() error {
    71  	// FIXME: implement
    72  	return fmt.Errorf("drop unimplemented because of duckdb size problems and not enough time during prototyping")
    73  }
    74  
    75  func (d *DuckDB) SetVersion(version int, dirty bool) error {
    76  	tx, err := d.db.Begin()
    77  	if err != nil {
    78  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
    79  	}
    80  
    81  	query := "DELETE FROM " + MigrationTable
    82  	if _, err := tx.Exec(query); err != nil {
    83  		return &database.Error{OrigErr: err, Query: []byte(query)}
    84  	}
    85  
    86  	// Also re-write the schema version for nil dirty versions to prevent
    87  	// empty schema version for failed down migration on the first migration
    88  	// See: https://github.com/golang-migrate/migrate/issues/330
    89  	//
    90  	// NOTE: Copied from sqlite implementation, unsure if this is necessary for
    91  	// duckdb
    92  	if version >= 0 || (version == database.NilVersion && dirty) {
    93  		query := fmt.Sprintf(`INSERT INTO %s (version, dirty) VALUES (?, ?)`, MigrationTable)
    94  		if _, err := tx.Exec(query, version, dirty); err != nil {
    95  			if errRollback := tx.Rollback(); errRollback != nil {
    96  				err = multierror.Append(err, errRollback)
    97  			}
    98  			return &database.Error{OrigErr: err, Query: []byte(query)}
    99  		}
   100  	}
   101  
   102  	if err := tx.Commit(); err != nil {
   103  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   104  	}
   105  
   106  	return nil
   107  }
   108  
   109  func (m *DuckDB) Version() (version int, dirty bool, err error) {
   110  	query := "SELECT version, dirty FROM " + MigrationTable + " LIMIT 1"
   111  	err = m.db.QueryRow(query).Scan(&version, &dirty)
   112  	if err != nil {
   113  		return database.NilVersion, false, nil
   114  	}
   115  	return version, dirty, nil
   116  }
   117  
   118  func (d *DuckDB) Run(migration io.Reader) error {
   119  	migr, err := io.ReadAll(migration)
   120  	if err != nil {
   121  		return fmt.Errorf("reading migration: %w", err)
   122  	}
   123  	query := string(migr[:])
   124  
   125  	tx, err := d.db.Begin()
   126  	if err != nil {
   127  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   128  	}
   129  	if _, err := tx.Exec(query); err != nil {
   130  		if errRollback := tx.Rollback(); errRollback != nil {
   131  			err = multierror.Append(err, errRollback)
   132  		}
   133  		return &database.Error{OrigErr: err, Query: []byte(query)}
   134  	}
   135  	if err := tx.Commit(); err != nil {
   136  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   137  	}
   138  	return nil
   139  }
   140  
   141  // ensureVersionTable checks if versions table exists and, if not, creates it.
   142  // Note that this function locks the database, which deviates from the usual
   143  // convention of "caller locks" in the Sqlite type.
   144  func (d *DuckDB) ensureVersionTable() (err error) {
   145  	if err = d.Lock(); err != nil {
   146  		return err
   147  	}
   148  
   149  	defer func() {
   150  		if e := d.Unlock(); e != nil {
   151  			if err == nil {
   152  				err = e
   153  			} else {
   154  				err = multierror.Append(err, e)
   155  			}
   156  		}
   157  	}()
   158  
   159  	query := fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s (version BIGINT, dirty BOOLEAN);`, MigrationTable)
   160  
   161  	if _, err := d.db.Exec(query); err != nil {
   162  		return fmt.Errorf("creating version table via '%s': %w", query, err)
   163  	}
   164  	return nil
   165  }