github.com/friesencr/pop/v6@v6.1.6/dialect_sqlite.go (about)

     1  package pop
     2  
     3  import (
     4  	"database/sql"
     5  	"database/sql/driver"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"net/url"
    10  	"os"
    11  	"os/exec"
    12  	"path/filepath"
    13  	"strings"
    14  	"sync"
    15  	"time"
    16  
    17  	"github.com/friesencr/pop/v6/columns"
    18  	"github.com/friesencr/pop/v6/internal/defaults"
    19  	"github.com/friesencr/pop/v6/logging"
    20  	"github.com/gobuffalo/fizz"
    21  	"github.com/gobuffalo/fizz/translators"
    22  	"github.com/jmoiron/sqlx"
    23  )
    24  
// nameSQLite3 is the name under which the SQLite dialect is registered. The
// same string is used as the database/sql driver name when this file checks
// for, or opens, a SQLite driver.
const nameSQLite3 = "sqlite3"
    26  
    27  func init() {
    28  	AvailableDialects = append(AvailableDialects, nameSQLite3)
    29  	dialectSynonyms["sqlite"] = nameSQLite3
    30  	urlParser[nameSQLite3] = urlParserSQLite3
    31  	newConnection[nameSQLite3] = newSQLite
    32  	finalizer[nameSQLite3] = finalizerSQLite
    33  }
    34  
    35  var _ dialect = &sqlite{}
    36  
    37  type sqlite struct {
    38  	commonDialect
    39  	gil   *sync.Mutex
    40  	smGil *sync.Mutex
    41  }
    42  
    43  func requireSQLite3() error {
    44  	for _, driverName := range sql.Drivers() {
    45  		if driverName == nameSQLite3 {
    46  			return nil
    47  		}
    48  	}
    49  	return errors.New("sqlite3 support was not compiled into the binary")
    50  }
    51  
// Name returns the name of this dialect, nameSQLite3.
func (m *sqlite) Name() string {
	return nameSQLite3
}
    55  
// DefaultDriver returns the default driver name for this dialect, which is
// the same string as the dialect name (nameSQLite3).
func (m *sqlite) DefaultDriver() string {
	return nameSQLite3
}
    59  
    60  func (m *sqlite) Details() *ConnectionDetails {
    61  	return m.ConnectionDetails
    62  }
    63  
    64  func (m *sqlite) URL() string {
    65  	c := m.ConnectionDetails
    66  	return c.Database + "?" + c.OptionsString("")
    67  }
    68  
    69  func (m *sqlite) MigrationURL() string {
    70  	return m.ConnectionDetails.URL
    71  }
    72  
    73  func (m *sqlite) Create(c *Connection, model *Model, cols columns.Columns) error {
    74  	return m.locker(m.smGil, func() error {
    75  		keyType, err := model.PrimaryKeyType()
    76  		if err != nil {
    77  			return err
    78  		}
    79  		switch keyType {
    80  		case "int", "int64":
    81  			var id int64
    82  			cols.Remove(model.IDField())
    83  			w := cols.Writeable()
    84  			var query string
    85  			if len(w.Cols) > 0 {
    86  				query = fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", m.Quote(model.TableName()), w.QuotedString(m), w.SymbolizedString())
    87  			} else {
    88  				query = fmt.Sprintf("INSERT INTO %s DEFAULT VALUES", m.Quote(model.TableName()))
    89  			}
    90  			txlog(logging.SQL, c, query, model.Value)
    91  			res, err := c.Store.NamedExecContext(model.ctx, query, model.Value)
    92  			if err != nil {
    93  				return err
    94  			}
    95  			id, err = res.LastInsertId()
    96  			if err == nil {
    97  				model.setID(id)
    98  			}
    99  			if err != nil {
   100  				return err
   101  			}
   102  			return nil
   103  		}
   104  		if err := genericCreate(c, model, cols, m); err != nil {
   105  			return fmt.Errorf("sqlite create: %w", err)
   106  		}
   107  		return nil
   108  	})
   109  }
   110  
   111  func (m *sqlite) Update(c *Connection, model *Model, cols columns.Columns) error {
   112  	return m.locker(m.smGil, func() error {
   113  		if err := genericUpdate(c, model, cols, m); err != nil {
   114  			return fmt.Errorf("sqlite update: %w", err)
   115  		}
   116  		return nil
   117  	})
   118  }
   119  
   120  func (m *sqlite) UpdateQuery(c *Connection, model *Model, cols columns.Columns, query Query) (int64, error) {
   121  	rowsAffected := int64(0)
   122  	err := m.locker(m.smGil, func() error {
   123  		if n, err := genericUpdateQuery(c, model, cols, m, query, sqlx.QUESTION); err != nil {
   124  			rowsAffected = n
   125  			return fmt.Errorf("sqlite update query: %w", err)
   126  		} else {
   127  			rowsAffected = n
   128  			return nil
   129  		}
   130  	})
   131  	return rowsAffected, err
   132  }
   133  
   134  func (m *sqlite) Destroy(c *Connection, model *Model) error {
   135  	return m.locker(m.smGil, func() error {
   136  		if err := genericDestroy(c, model, m); err != nil {
   137  			return fmt.Errorf("sqlite destroy: %w", err)
   138  		}
   139  		return nil
   140  	})
   141  }
   142  
// Delete delegates directly to genericDelete and returns its error as is.
//
// NOTE(review): unlike Create, Update, UpdateQuery, Destroy, SelectOne and
// SelectMany, this method neither goes through locker (so it takes no mutex
// and gets no "database is locked" retry) nor adds a "sqlite ..." prefix to
// the error. Confirm whether that difference is intentional.
func (m *sqlite) Delete(c *Connection, model *Model, query Query) error {
	return genericDelete(c, model, query)
}
   146  
   147  func (m *sqlite) SelectOne(c *Connection, model *Model, query Query) error {
   148  	return m.locker(m.smGil, func() error {
   149  		if err := genericSelectOne(c, model, query); err != nil {
   150  			return fmt.Errorf("sqlite select one: %w", err)
   151  		}
   152  		return nil
   153  	})
   154  }
   155  
   156  func (m *sqlite) SelectMany(c *Connection, models *Model, query Query) error {
   157  	return m.locker(m.smGil, func() error {
   158  		if err := genericSelectMany(c, models, query); err != nil {
   159  			return fmt.Errorf("sqlite select many: %w", err)
   160  		}
   161  		return nil
   162  	})
   163  }
   164  
// Lock runs fn through locker using the gil mutex, which is separate from the
// smGil mutex used by the statement-level methods. It returns whatever locker
// returns for fn.
func (m *sqlite) Lock(fn func() error) error {
	return m.locker(m.gil, fn)
}
   168  
   169  func (m *sqlite) locker(l *sync.Mutex, fn func() error) error {
   170  	if defaults.String(m.Details().option("lock"), "true") == "true" {
   171  		defer l.Unlock()
   172  		l.Lock()
   173  	}
   174  	err := fn()
   175  	attempts := 0
   176  	for err != nil && err.Error() == "database is locked" && attempts <= m.Details().RetryLimit() {
   177  		time.Sleep(m.Details().RetrySleep())
   178  		err = fn()
   179  		attempts++
   180  	}
   181  	return err
   182  }
   183  
   184  func (m *sqlite) CreateDB() error {
   185  	durl := m.ConnectionDetails.Database
   186  
   187  	// Checking whether the url specifies in-memory mode
   188  	// as specified in https://github.com/mattn/go-sqlite3#faq
   189  	if strings.Contains(durl, ":memory:") || strings.Contains(durl, "mode=memory") {
   190  		log(logging.Info, "in memory db selected, no database file created.")
   191  
   192  		return nil
   193  	}
   194  
   195  	_, err := os.Stat(durl)
   196  	if err == nil {
   197  		return fmt.Errorf("could not create SQLite database '%s'; database exists", durl)
   198  	}
   199  	dir := filepath.Dir(durl)
   200  	err = os.MkdirAll(dir, 0766)
   201  	if err != nil {
   202  		return fmt.Errorf("could not create SQLite database '%s': %w", durl, err)
   203  	}
   204  	f, err := os.Create(durl)
   205  	if err != nil {
   206  		return fmt.Errorf("could not create SQLite database '%s': %w", durl, err)
   207  	}
   208  	_ = f.Close()
   209  
   210  	log(logging.Info, "created database '%s'", durl)
   211  	return nil
   212  }
   213  
   214  func (m *sqlite) DropDB() error {
   215  	err := os.Remove(m.ConnectionDetails.Database)
   216  	if err != nil {
   217  		return fmt.Errorf("could not drop SQLite database %s: %w", m.ConnectionDetails.Database, err)
   218  	}
   219  	log(logging.Info, "dropped database '%s'", m.ConnectionDetails.Database)
   220  	return nil
   221  }
   222  
// TranslateSQL returns sql unchanged; this dialect performs no rewriting of
// the statement text.
func (m *sqlite) TranslateSQL(sql string) string {
	return sql
}
   226  
   227  func (m *sqlite) FizzTranslator() fizz.Translator {
   228  	return translators.NewSQLite(m.Details().Database)
   229  }
   230  
   231  func (m *sqlite) DumpSchema(w io.Writer) error {
   232  	cmd := exec.Command("sqlite3", m.Details().Database, ".schema")
   233  	return genericDumpSchema(m.Details(), cmd, w)
   234  }
   235  
   236  func (m *sqlite) LoadSchema(r io.Reader) error {
   237  	cmd := exec.Command("sqlite3", m.ConnectionDetails.Database)
   238  	in, err := cmd.StdinPipe()
   239  	if err != nil {
   240  		return err
   241  	}
   242  	go func() {
   243  		defer in.Close()
   244  		io.Copy(in, r)
   245  	}()
   246  	log(logging.SQL, strings.Join(cmd.Args, " "))
   247  	err = cmd.Start()
   248  	if err != nil {
   249  		return err
   250  	}
   251  
   252  	err = cmd.Wait()
   253  	if err != nil {
   254  		return err
   255  	}
   256  
   257  	log(logging.Info, "loaded schema for %s", m.Details().Database)
   258  	return nil
   259  }
   260  
   261  func (m *sqlite) TruncateAll(tx *Connection) error {
   262  	const tableNames = `SELECT name FROM sqlite_master WHERE type = "table"`
   263  	names := []struct {
   264  		Name string `db:"name"`
   265  	}{}
   266  
   267  	err := tx.RawQuery(tableNames).All(&names)
   268  	if err != nil {
   269  		return err
   270  	}
   271  	if len(names) == 0 {
   272  		return nil
   273  	}
   274  	stmts := []string{}
   275  	for _, n := range names {
   276  		stmts = append(stmts, fmt.Sprintf("DELETE FROM %s", m.Quote(n.Name)))
   277  	}
   278  	return tx.RawQuery(strings.Join(stmts, "; ")).Exec()
   279  }
   280  
   281  func newSQLite(deets *ConnectionDetails) (dialect, error) {
   282  	err := requireSQLite3()
   283  	if err != nil {
   284  		return nil, err
   285  	}
   286  	deets.URL = fmt.Sprintf("sqlite3://%s", deets.Database)
   287  	cd := &sqlite{
   288  		gil:           &sync.Mutex{},
   289  		smGil:         &sync.Mutex{},
   290  		commonDialect: commonDialect{ConnectionDetails: deets},
   291  	}
   292  
   293  	return cd, nil
   294  }
   295  
   296  func urlParserSQLite3(cd *ConnectionDetails) error {
   297  	db := strings.TrimPrefix(cd.URL, "sqlite://")
   298  	db = strings.TrimPrefix(db, "sqlite3://")
   299  
   300  	dbparts := strings.Split(db, "?")
   301  	cd.Database = dbparts[0]
   302  
   303  	if len(dbparts) != 2 {
   304  		return nil
   305  	}
   306  
   307  	q, err := url.ParseQuery(dbparts[1])
   308  	if err != nil {
   309  		return fmt.Errorf("unable to parse sqlite query: %w", err)
   310  	}
   311  
   312  	for k := range q {
   313  		cd.setOption(k, q.Get(k))
   314  	}
   315  
   316  	return nil
   317  }
   318  
   319  func finalizerSQLite(cd *ConnectionDetails) {
   320  	defs := map[string]string{
   321  		"_busy_timeout": "5000",
   322  	}
   323  	forced := map[string]string{
   324  		"_fk": "true",
   325  	}
   326  
   327  	for k, def := range defs {
   328  		cd.setOptionWithDefault(k, cd.option(k), def)
   329  	}
   330  
   331  	for k, v := range forced {
   332  		// respect user specified options but print warning!
   333  		cd.setOptionWithDefault(k, cd.option(k), v)
   334  		if cd.option(k) != v { // when user-defined option exists
   335  			log(logging.Warn, "IMPORTANT! '%s: %s' option is required to work properly but your current setting is '%v: %v'.", k, v, k, cd.option(k))
   336  			log(logging.Warn, "It is highly recommended to remove '%v: %v' option from your config!", k, cd.option(k))
   337  		} // or override with `cd.Options[k] = v`?
   338  		if cd.URL != "" && !strings.Contains(cd.URL, k+"="+v) {
   339  			log(logging.Warn, "IMPORTANT! '%s=%s' option is required to work properly. Please add it to the database URL in the config!", k, v)
   340  		} // or fix user specified url?
   341  	}
   342  }
   343  
   344  func newSQLiteDriver() (driver.Driver, error) {
   345  	err := requireSQLite3()
   346  	if err != nil {
   347  		return nil, err
   348  	}
   349  	db, err := sql.Open(nameSQLite3, ":memory:?cache=newSQLiteDriver_temporary")
   350  	if err != nil {
   351  		return nil, err
   352  	}
   353  	return db.Driver(), db.Close()
   354  }