github.com/paweljw/pop/v5@v5.4.6/dialect_sqlite.go (about)

     1  // +build sqlite
     2  
     3  package pop
     4  
     5  import (
     6  	"database/sql/driver"
     7  	"fmt"
     8  	"io"
     9  	"net/url"
    10  	"os"
    11  	"os/exec"
    12  	"path/filepath"
    13  	"strings"
    14  	"sync"
    15  	"time"
    16  
    17  	"github.com/mattn/go-sqlite3"
    18  
    19  	"github.com/gobuffalo/fizz"
    20  	"github.com/gobuffalo/fizz/translators"
    21  	_ "github.com/mattn/go-sqlite3" // Load SQLite3 CGo driver
    22  	"github.com/pkg/errors"
    23  
    24  	"github.com/gobuffalo/pop/v5/columns"
    25  	"github.com/gobuffalo/pop/v5/internal/defaults"
    26  	"github.com/gobuffalo/pop/v5/logging"
    27  )
    28  
    29  const nameSQLite3 = "sqlite3"
    30  
    31  func init() {
    32  	AvailableDialects = append(AvailableDialects, nameSQLite3)
    33  	dialectSynonyms["sqlite"] = nameSQLite3
    34  	urlParser[nameSQLite3] = urlParserSQLite3
    35  	newConnection[nameSQLite3] = newSQLite
    36  	finalizer[nameSQLite3] = finalizerSQLite
    37  }
    38  
    39  var _ dialect = &sqlite{}
    40  
    41  type sqlite struct {
    42  	commonDialect
    43  	gil   *sync.Mutex
    44  	smGil *sync.Mutex
    45  }
    46  
    47  func (m *sqlite) Name() string {
    48  	return nameSQLite3
    49  }
    50  
    51  func (m *sqlite) DefaultDriver() string {
    52  	return nameSQLite3
    53  }
    54  
    55  func (m *sqlite) Details() *ConnectionDetails {
    56  	return m.ConnectionDetails
    57  }
    58  
    59  func (m *sqlite) URL() string {
    60  	c := m.ConnectionDetails
    61  	return c.Database + "?" + c.OptionsString("")
    62  }
    63  
    64  func (m *sqlite) MigrationURL() string {
    65  	return m.ConnectionDetails.URL
    66  }
    67  
    68  func (m *sqlite) Create(s store, model *Model, cols columns.Columns) error {
    69  	return m.locker(m.smGil, func() error {
    70  		keyType, err := model.PrimaryKeyType()
    71  		if err != nil {
    72  			return err
    73  		}
    74  		switch keyType {
    75  		case "int", "int64":
    76  			var id int64
    77  			cols.Remove(model.IDField())
    78  			w := cols.Writeable()
    79  			var query string
    80  			if len(w.Cols) > 0 {
    81  				query = fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", m.Quote(model.TableName()), w.QuotedString(m), w.SymbolizedString())
    82  			} else {
    83  				query = fmt.Sprintf("INSERT INTO %s DEFAULT VALUES", m.Quote(model.TableName()))
    84  			}
    85  			log(logging.SQL, query)
    86  			res, err := s.NamedExec(query, model.Value)
    87  			if err != nil {
    88  				return err
    89  			}
    90  			id, err = res.LastInsertId()
    91  			if err == nil {
    92  				model.setID(id)
    93  			}
    94  			if err != nil {
    95  				return err
    96  			}
    97  			return nil
    98  		}
    99  		return errors.Wrap(genericCreate(s, model, cols, m), "sqlite create")
   100  	})
   101  }
   102  
   103  func (m *sqlite) Update(s store, model *Model, cols columns.Columns) error {
   104  	return m.locker(m.smGil, func() error {
   105  		return errors.Wrap(genericUpdate(s, model, cols, m), "sqlite update")
   106  	})
   107  }
   108  
   109  func (m *sqlite) Destroy(s store, model *Model) error {
   110  	return m.locker(m.smGil, func() error {
   111  		return errors.Wrap(genericDestroy(s, model, m), "sqlite destroy")
   112  	})
   113  }
   114  
   115  func (m *sqlite) SelectOne(s store, model *Model, query Query) error {
   116  	return m.locker(m.smGil, func() error {
   117  		return errors.Wrap(genericSelectOne(s, model, query), "sqlite select one")
   118  	})
   119  }
   120  
   121  func (m *sqlite) SelectMany(s store, models *Model, query Query) error {
   122  	return m.locker(m.smGil, func() error {
   123  		return errors.Wrap(genericSelectMany(s, models, query), "sqlite select many")
   124  	})
   125  }
   126  
   127  func (m *sqlite) Lock(fn func() error) error {
   128  	return m.locker(m.gil, fn)
   129  }
   130  
   131  func (m *sqlite) locker(l *sync.Mutex, fn func() error) error {
   132  	if defaults.String(m.Details().Options["lock"], "true") == "true" {
   133  		defer l.Unlock()
   134  		l.Lock()
   135  	}
   136  	err := fn()
   137  	attempts := 0
   138  	for err != nil && err.Error() == "database is locked" && attempts <= m.Details().RetryLimit() {
   139  		time.Sleep(m.Details().RetrySleep())
   140  		err = fn()
   141  		attempts++
   142  	}
   143  	return err
   144  }
   145  
   146  func (m *sqlite) CreateDB() error {
   147  	_, err := os.Stat(m.ConnectionDetails.Database)
   148  	if err == nil {
   149  		return errors.Errorf("could not create SQLite database '%s'; database exists", m.ConnectionDetails.Database)
   150  	}
   151  	dir := filepath.Dir(m.ConnectionDetails.Database)
   152  	err = os.MkdirAll(dir, 0766)
   153  	if err != nil {
   154  		return errors.Wrapf(err, "could not create SQLite database '%s'", m.ConnectionDetails.Database)
   155  	}
   156  	_, err = os.Create(m.ConnectionDetails.Database)
   157  	if err != nil {
   158  		return errors.Wrapf(err, "could not create SQLite database '%s'", m.ConnectionDetails.Database)
   159  	}
   160  
   161  	log(logging.Info, "created database '%s'", m.ConnectionDetails.Database)
   162  	return nil
   163  }
   164  
   165  func (m *sqlite) DropDB() error {
   166  	err := os.Remove(m.ConnectionDetails.Database)
   167  	if err != nil {
   168  		return errors.Wrapf(err, "could not drop SQLite database %s", m.ConnectionDetails.Database)
   169  	}
   170  	log(logging.Info, "dropped database '%s'", m.ConnectionDetails.Database)
   171  	return nil
   172  }
   173  
   174  func (m *sqlite) TranslateSQL(sql string) string {
   175  	return sql
   176  }
   177  
   178  func (m *sqlite) FizzTranslator() fizz.Translator {
   179  	return translators.NewSQLite(m.Details().Database)
   180  }
   181  
   182  func (m *sqlite) DumpSchema(w io.Writer) error {
   183  	cmd := exec.Command("sqlite3", m.Details().Database, ".schema")
   184  	return genericDumpSchema(m.Details(), cmd, w)
   185  }
   186  
   187  func (m *sqlite) LoadSchema(r io.Reader) error {
   188  	cmd := exec.Command("sqlite3", m.ConnectionDetails.Database)
   189  	in, err := cmd.StdinPipe()
   190  	if err != nil {
   191  		return err
   192  	}
   193  	go func() {
   194  		defer in.Close()
   195  		io.Copy(in, r)
   196  	}()
   197  	log(logging.SQL, strings.Join(cmd.Args, " "))
   198  	err = cmd.Start()
   199  	if err != nil {
   200  		return err
   201  	}
   202  
   203  	err = cmd.Wait()
   204  	if err != nil {
   205  		return err
   206  	}
   207  
   208  	log(logging.Info, "loaded schema for %s", m.Details().Database)
   209  	return nil
   210  }
   211  
   212  func (m *sqlite) TruncateAll(tx *Connection) error {
   213  	const tableNames = `SELECT name FROM sqlite_master WHERE type = "table"`
   214  	names := []struct {
   215  		Name string `db:"name"`
   216  	}{}
   217  
   218  	err := tx.RawQuery(tableNames).All(&names)
   219  	if err != nil {
   220  		return err
   221  	}
   222  	if len(names) == 0 {
   223  		return nil
   224  	}
   225  	stmts := []string{}
   226  	for _, n := range names {
   227  		stmts = append(stmts, fmt.Sprintf("DELETE FROM %s", m.Quote(n.Name)))
   228  	}
   229  	return tx.RawQuery(strings.Join(stmts, "; ")).Exec()
   230  }
   231  
   232  func newSQLite(deets *ConnectionDetails) (dialect, error) {
   233  	deets.URL = fmt.Sprintf("sqlite3://%s", deets.Database)
   234  	cd := &sqlite{
   235  		gil:           &sync.Mutex{},
   236  		smGil:         &sync.Mutex{},
   237  		commonDialect: commonDialect{ConnectionDetails: deets},
   238  	}
   239  
   240  	return cd, nil
   241  }
   242  
   243  func urlParserSQLite3(cd *ConnectionDetails) error {
   244  	db := strings.TrimPrefix(cd.URL, "sqlite://")
   245  	db = strings.TrimPrefix(db, "sqlite3://")
   246  
   247  	dbparts := strings.Split(db, "?")
   248  	cd.Database = dbparts[0]
   249  
   250  	if len(dbparts) != 2 {
   251  		return nil
   252  	}
   253  
   254  	q, err := url.ParseQuery(dbparts[1])
   255  	if err != nil {
   256  		return errors.Wrapf(err, "unable to parse sqlite query")
   257  	}
   258  
   259  	if cd.Options == nil { // prevent panic
   260  		cd.Options = make(map[string]string)
   261  	}
   262  	for k := range q {
   263  		cd.Options[k] = q.Get(k)
   264  	}
   265  
   266  	return nil
   267  }
   268  
   269  func finalizerSQLite(cd *ConnectionDetails) {
   270  	defs := map[string]string{
   271  		"_busy_timeout": "5000",
   272  	}
   273  	forced := map[string]string{
   274  		"_fk": "true",
   275  	}
   276  	if cd.Options == nil { // prevent panic
   277  		cd.Options = make(map[string]string)
   278  	}
   279  
   280  	for k, v := range defs {
   281  		cd.Options[k] = defaults.String(cd.Options[k], v)
   282  	}
   283  
   284  	for k, v := range forced {
   285  		// respect user specified options but print warning!
   286  		cd.Options[k] = defaults.String(cd.Options[k], v)
   287  		if cd.Options[k] != v { // when user-defined option exists
   288  			log(logging.Warn, "IMPORTANT! '%s: %s' option is required to work properly but your current setting is '%v: %v'.", k, v, k, cd.Options[k])
   289  			log(logging.Warn, "It is highly recommended to remove '%v: %v' option from your config!", k, cd.Options[k])
   290  		} // or override with `cd.Options[k] = v`?
   291  		if cd.URL != "" && !strings.Contains(cd.URL, k+"="+v) {
   292  			log(logging.Warn, "IMPORTANT! '%s=%s' option is required to work properly. Please add it to the database URL in the config!", k, v)
   293  		} // or fix user specified url?
   294  	}
   295  }
   296  
   297  func newSQLiteDriver() (driver.Driver, error) {
   298  	return new(sqlite3.SQLiteDriver), nil
   299  }