github.com/goravel/framework@v1.13.9/database/console/driver/sqlite.go (about)

     1  package sqlite
     2  
     3  import (
     4  	"database/sql"
     5  	"fmt"
     6  	"io"
     7  	nurl "net/url"
     8  	"strconv"
     9  	"strings"
    10  
    11  	_ "github.com/glebarez/go-sqlite"
    12  	"github.com/golang-migrate/migrate/v4"
    13  	"github.com/golang-migrate/migrate/v4/database"
    14  	"github.com/hashicorp/go-multierror"
    15  	"go.uber.org/atomic"
    16  )
    17  
    18  func init() {
    19  	database.Register("sqlite", &Sqlite{})
    20  }
    21  
    22  var DefaultMigrationsTable = "schema_migrations"
    23  var (
    24  	ErrDatabaseDirty  = fmt.Errorf("database is dirty")
    25  	ErrNilConfig      = fmt.Errorf("no config")
    26  	ErrNoDatabaseName = fmt.Errorf("no database name")
    27  )
    28  
    29  type Config struct {
    30  	MigrationsTable string
    31  	DatabaseName    string
    32  	NoTxWrap        bool
    33  }
    34  
    35  type Sqlite struct {
    36  	db       *sql.DB
    37  	isLocked atomic.Bool
    38  
    39  	config *Config
    40  }
    41  
    42  func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
    43  	if config == nil {
    44  		return nil, ErrNilConfig
    45  	}
    46  
    47  	if err := instance.Ping(); err != nil {
    48  		return nil, err
    49  	}
    50  
    51  	if len(config.MigrationsTable) == 0 {
    52  		config.MigrationsTable = DefaultMigrationsTable
    53  	}
    54  
    55  	mx := &Sqlite{
    56  		db:     instance,
    57  		config: config,
    58  	}
    59  	if err := mx.ensureVersionTable(); err != nil {
    60  		return nil, err
    61  	}
    62  	return mx, nil
    63  }
    64  
    65  // ensureVersionTable checks if versions table exists and, if not, creates it.
    66  // Note that this function locks the database, which deviates from the usual
    67  // convention of "caller locks" in the Sqlite type.
    68  func (m *Sqlite) ensureVersionTable() (err error) {
    69  	if err = m.Lock(); err != nil {
    70  		return err
    71  	}
    72  
    73  	defer func() {
    74  		if e := m.Unlock(); e != nil {
    75  			if err == nil {
    76  				err = e
    77  			} else {
    78  				err = multierror.Append(err, e)
    79  			}
    80  		}
    81  	}()
    82  
    83  	query := fmt.Sprintf(`
    84  	CREATE TABLE IF NOT EXISTS %s (version uint64,dirty bool);
    85    CREATE UNIQUE INDEX IF NOT EXISTS version_unique ON %s (version);
    86    `, m.config.MigrationsTable, m.config.MigrationsTable)
    87  
    88  	if _, err = m.db.Exec(query); err != nil {
    89  		return err
    90  	}
    91  	return nil
    92  }
    93  
    94  func (m *Sqlite) Open(url string) (database.Driver, error) {
    95  	purl, err := nurl.Parse(url)
    96  	if err != nil {
    97  		return nil, err
    98  	}
    99  	dbfile := strings.Replace(migrate.FilterCustomQuery(purl).String(), "sqlite://", "", 1)
   100  	db, err := sql.Open("sqlite", dbfile)
   101  	if err != nil {
   102  		return nil, err
   103  	}
   104  
   105  	qv := purl.Query()
   106  
   107  	migrationsTable := qv.Get("x-migrations-table")
   108  	if len(migrationsTable) == 0 {
   109  		migrationsTable = DefaultMigrationsTable
   110  	}
   111  
   112  	noTxWrap := false
   113  	if v := qv.Get("x-no-tx-wrap"); v != "" {
   114  		noTxWrap, err = strconv.ParseBool(v)
   115  		if err != nil {
   116  			return nil, fmt.Errorf("x-no-tx-wrap: %s", err)
   117  		}
   118  	}
   119  
   120  	mx, err := WithInstance(db, &Config{
   121  		DatabaseName:    purl.Path,
   122  		MigrationsTable: migrationsTable,
   123  		NoTxWrap:        noTxWrap,
   124  	})
   125  	if err != nil {
   126  		return nil, err
   127  	}
   128  	return mx, nil
   129  }
   130  
   131  func (m *Sqlite) Close() error {
   132  	return m.db.Close()
   133  }
   134  
   135  func (m *Sqlite) Drop() (err error) {
   136  	query := `SELECT name FROM sqlite_master WHERE type = 'table';`
   137  	tables, err := m.db.Query(query)
   138  	if err != nil {
   139  		return &database.Error{OrigErr: err, Query: []byte(query)}
   140  	}
   141  	defer func() {
   142  		if errClose := tables.Close(); errClose != nil {
   143  			err = multierror.Append(err, errClose)
   144  		}
   145  	}()
   146  
   147  	tableNames := make([]string, 0)
   148  	for tables.Next() {
   149  		var tableName string
   150  		if err = tables.Scan(&tableName); err != nil {
   151  			return err
   152  		}
   153  		if len(tableName) > 0 {
   154  			tableNames = append(tableNames, tableName)
   155  		}
   156  	}
   157  	if err = tables.Err(); err != nil {
   158  		return &database.Error{OrigErr: err, Query: []byte(query)}
   159  	}
   160  
   161  	if len(tableNames) > 0 {
   162  		for _, t := range tableNames {
   163  			// SQLite has a sqlite_sequence table and it cannot be dropped
   164  			if t == "sqlite_sequence" {
   165  				_, err = m.db.Exec("DELETE FROM sqlite_sequence;")
   166  				if err != nil {
   167  					return &database.Error{OrigErr: err, Query: []byte("DELETE FROM sqlite_sequence;")}
   168  				}
   169  
   170  				continue
   171  			}
   172  
   173  			query = "DROP TABLE " + t
   174  			err = m.executeQuery(query)
   175  			if err != nil {
   176  				return &database.Error{OrigErr: err, Query: []byte(query)}
   177  			}
   178  		}
   179  		query = "VACUUM"
   180  		_, err = m.db.Exec(query)
   181  		if err != nil {
   182  			return &database.Error{OrigErr: err, Query: []byte(query)}
   183  		}
   184  	}
   185  
   186  	return nil
   187  }
   188  
   189  func (m *Sqlite) Lock() error {
   190  	if !m.isLocked.CompareAndSwap(false, true) {
   191  		return database.ErrLocked
   192  	}
   193  	return nil
   194  }
   195  
   196  func (m *Sqlite) Unlock() error {
   197  	if !m.isLocked.CompareAndSwap(true, false) {
   198  		return database.ErrNotLocked
   199  	}
   200  	return nil
   201  }
   202  
   203  func (m *Sqlite) Run(migration io.Reader) error {
   204  	migr, err := io.ReadAll(migration)
   205  	if err != nil {
   206  		return err
   207  	}
   208  	query := string(migr[:])
   209  
   210  	if m.config.NoTxWrap {
   211  		return m.executeQueryNoTx(query)
   212  	}
   213  	return m.executeQuery(query)
   214  }
   215  
   216  func (m *Sqlite) executeQuery(query string) error {
   217  	tx, err := m.db.Begin()
   218  	if err != nil {
   219  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   220  	}
   221  	if _, err = tx.Exec(query); err != nil {
   222  		if errRollback := tx.Rollback(); errRollback != nil {
   223  			err = multierror.Append(err, errRollback)
   224  		}
   225  		return &database.Error{OrigErr: err, Query: []byte(query)}
   226  	}
   227  	if err = tx.Commit(); err != nil {
   228  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   229  	}
   230  	return nil
   231  }
   232  
   233  func (m *Sqlite) executeQueryNoTx(query string) error {
   234  	if _, err := m.db.Exec(query); err != nil {
   235  		return &database.Error{OrigErr: err, Query: []byte(query)}
   236  	}
   237  	return nil
   238  }
   239  
   240  func (m *Sqlite) SetVersion(version int, dirty bool) error {
   241  	tx, err := m.db.Begin()
   242  	if err != nil {
   243  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   244  	}
   245  
   246  	query := "DELETE FROM " + m.config.MigrationsTable
   247  	if _, err = tx.Exec(query); err != nil {
   248  		return &database.Error{OrigErr: err, Query: []byte(query)}
   249  	}
   250  
   251  	// Also re-write the schema version for nil dirty versions to prevent
   252  	// empty schema version for failed down migration on the first migration
   253  	// See: https://github.com/golang-migrate/migrate/issues/330
   254  	if version >= 0 || (version == database.NilVersion && dirty) {
   255  		query = fmt.Sprintf(`INSERT INTO %s (version, dirty) VALUES (?, ?)`, m.config.MigrationsTable)
   256  		if _, err = tx.Exec(query, version, dirty); err != nil {
   257  			if errRollback := tx.Rollback(); errRollback != nil {
   258  				err = multierror.Append(err, errRollback)
   259  			}
   260  			return &database.Error{OrigErr: err, Query: []byte(query)}
   261  		}
   262  	}
   263  
   264  	if err = tx.Commit(); err != nil {
   265  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   266  	}
   267  
   268  	return nil
   269  }
   270  
   271  func (m *Sqlite) Version() (version int, dirty bool, err error) {
   272  	query := "SELECT version, dirty FROM " + m.config.MigrationsTable + " LIMIT 1"
   273  	err = m.db.QueryRow(query).Scan(&version, &dirty)
   274  	if err != nil {
   275  		return database.NilVersion, false, nil
   276  	}
   277  	return version, dirty, nil
   278  }