github.com/naemono/pop@v4.13.1+incompatible/dialect_sqlite.go (about)

     1  // +build sqlite
     2  
     3  package pop
     4  
     5  import (
     6  	"fmt"
     7  	"io"
     8  	"net/url"
     9  	"os"
    10  	"os/exec"
    11  	"path/filepath"
    12  	"strings"
    13  	"sync"
    14  	"time"
    15  
    16  	"github.com/gobuffalo/fizz"
    17  	"github.com/gobuffalo/fizz/translators"
    18  	_ "github.com/mattn/go-sqlite3" // Load SQLite3 CGo driver
    19  	"github.com/pkg/errors"
    20  
    21  	"github.com/gobuffalo/pop/columns"
    22  	"github.com/gobuffalo/pop/internal/defaults"
    23  	"github.com/gobuffalo/pop/logging"
    24  )
    25  
    26  const nameSQLite3 = "sqlite3"
    27  
    28  func init() {
    29  	AvailableDialects = append(AvailableDialects, nameSQLite3)
    30  	dialectSynonyms["sqlite"] = nameSQLite3
    31  	urlParser[nameSQLite3] = urlParserSQLite3
    32  	newConnection[nameSQLite3] = newSQLite
    33  	finalizer[nameSQLite3] = finalizerSQLite
    34  }
    35  
    36  var _ dialect = &sqlite{}
    37  
    38  type sqlite struct {
    39  	commonDialect
    40  	gil   *sync.Mutex
    41  	smGil *sync.Mutex
    42  }
    43  
    44  func (m *sqlite) Name() string {
    45  	return nameSQLite3
    46  }
    47  
    48  func (m *sqlite) Details() *ConnectionDetails {
    49  	return m.ConnectionDetails
    50  }
    51  
    52  func (m *sqlite) URL() string {
    53  	c := m.ConnectionDetails
    54  	return c.Database + "?" + c.OptionsString("")
    55  }
    56  
    57  func (m *sqlite) MigrationURL() string {
    58  	return m.ConnectionDetails.URL
    59  }
    60  
    61  func (m *sqlite) Create(s store, model *Model, cols columns.Columns) error {
    62  	return m.locker(m.smGil, func() error {
    63  		keyType := model.PrimaryKeyType()
    64  		switch keyType {
    65  		case "int", "int64":
    66  			var id int64
    67  			w := cols.Writeable()
    68  			var query string
    69  			if len(w.Cols) > 0 {
    70  				query = fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", m.Quote(model.TableName()), w.QuotedString(m), w.SymbolizedString())
    71  			} else {
    72  				query = fmt.Sprintf("INSERT INTO %s DEFAULT VALUES", m.Quote(model.TableName()))
    73  			}
    74  			log(logging.SQL, query)
    75  			res, err := s.NamedExec(query, model.Value)
    76  			if err != nil {
    77  				return err
    78  			}
    79  			id, err = res.LastInsertId()
    80  			if err == nil {
    81  				model.setID(id)
    82  			}
    83  			if err != nil {
    84  				return err
    85  			}
    86  			return nil
    87  		}
    88  		return errors.Wrap(genericCreate(s, model, cols, m), "sqlite create")
    89  	})
    90  }
    91  
    92  func (m *sqlite) Update(s store, model *Model, cols columns.Columns) error {
    93  	return m.locker(m.smGil, func() error {
    94  		return errors.Wrap(genericUpdate(s, model, cols, m), "sqlite update")
    95  	})
    96  }
    97  
    98  func (m *sqlite) Destroy(s store, model *Model) error {
    99  	return m.locker(m.smGil, func() error {
   100  		return errors.Wrap(genericDestroy(s, model, m), "sqlite destroy")
   101  	})
   102  }
   103  
   104  func (m *sqlite) SelectOne(s store, model *Model, query Query) error {
   105  	return m.locker(m.smGil, func() error {
   106  		return errors.Wrap(genericSelectOne(s, model, query), "sqlite select one")
   107  	})
   108  }
   109  
   110  func (m *sqlite) SelectMany(s store, models *Model, query Query) error {
   111  	return m.locker(m.smGil, func() error {
   112  		return errors.Wrap(genericSelectMany(s, models, query), "sqlite select many")
   113  	})
   114  }
   115  
   116  func (m *sqlite) Lock(fn func() error) error {
   117  	return m.locker(m.gil, fn)
   118  }
   119  
   120  func (m *sqlite) locker(l *sync.Mutex, fn func() error) error {
   121  	if defaults.String(m.Details().Options["lock"], "true") == "true" {
   122  		defer l.Unlock()
   123  		l.Lock()
   124  	}
   125  	err := fn()
   126  	attempts := 0
   127  	for err != nil && err.Error() == "database is locked" && attempts <= m.Details().RetryLimit() {
   128  		time.Sleep(m.Details().RetrySleep())
   129  		err = fn()
   130  		attempts++
   131  	}
   132  	return err
   133  }
   134  
   135  func (m *sqlite) CreateDB() error {
   136  	d := filepath.Dir(m.ConnectionDetails.Database)
   137  	err := os.MkdirAll(d, 0766)
   138  	if err != nil {
   139  		return errors.Wrapf(err, "could not create SQLite database %s", m.ConnectionDetails.Database)
   140  	}
   141  	_, err = os.Create(m.ConnectionDetails.Database)
   142  	if err != nil {
   143  		return errors.Wrapf(err, "could not create SQLite database %s", m.ConnectionDetails.Database)
   144  	}
   145  
   146  	log(logging.Info, "created database %s", m.ConnectionDetails.Database)
   147  	return nil
   148  }
   149  
   150  func (m *sqlite) DropDB() error {
   151  	err := os.Remove(m.ConnectionDetails.Database)
   152  	if err != nil {
   153  		return errors.Wrapf(err, "could not drop SQLite database %s", m.ConnectionDetails.Database)
   154  	}
   155  	log(logging.Info, "dropped database %s", m.ConnectionDetails.Database)
   156  	return nil
   157  }
   158  
   159  func (m *sqlite) TranslateSQL(sql string) string {
   160  	return sql
   161  }
   162  
   163  func (m *sqlite) FizzTranslator() fizz.Translator {
   164  	return translators.NewSQLite(m.Details().Database)
   165  }
   166  
   167  func (m *sqlite) DumpSchema(w io.Writer) error {
   168  	cmd := exec.Command("sqlite3", m.Details().Database, ".schema")
   169  	return genericDumpSchema(m.Details(), cmd, w)
   170  }
   171  
   172  func (m *sqlite) LoadSchema(r io.Reader) error {
   173  	cmd := exec.Command("sqlite3", m.ConnectionDetails.Database)
   174  	in, err := cmd.StdinPipe()
   175  	if err != nil {
   176  		return err
   177  	}
   178  	go func() {
   179  		defer in.Close()
   180  		io.Copy(in, r)
   181  	}()
   182  	log(logging.SQL, strings.Join(cmd.Args, " "))
   183  	err = cmd.Start()
   184  	if err != nil {
   185  		return err
   186  	}
   187  
   188  	err = cmd.Wait()
   189  	if err != nil {
   190  		return err
   191  	}
   192  
   193  	log(logging.Info, "loaded schema for %s", m.Details().Database)
   194  	return nil
   195  }
   196  
   197  func (m *sqlite) TruncateAll(tx *Connection) error {
   198  	const tableNames = `SELECT name FROM sqlite_master WHERE type = "table"`
   199  	names := []struct {
   200  		Name string `db:"name"`
   201  	}{}
   202  
   203  	err := tx.RawQuery(tableNames).All(&names)
   204  	if err != nil {
   205  		return err
   206  	}
   207  	if len(names) == 0 {
   208  		return nil
   209  	}
   210  	stmts := []string{}
   211  	for _, n := range names {
   212  		stmts = append(stmts, fmt.Sprintf("DELETE FROM %s", m.Quote(n.Name)))
   213  	}
   214  	return tx.RawQuery(strings.Join(stmts, "; ")).Exec()
   215  }
   216  
   217  func newSQLite(deets *ConnectionDetails) (dialect, error) {
   218  	deets.URL = fmt.Sprintf("sqlite3://%s", deets.Database)
   219  	cd := &sqlite{
   220  		gil:           &sync.Mutex{},
   221  		smGil:         &sync.Mutex{},
   222  		commonDialect: commonDialect{ConnectionDetails: deets},
   223  	}
   224  
   225  	return cd, nil
   226  }
   227  
   228  func urlParserSQLite3(cd *ConnectionDetails) error {
   229  	db := strings.TrimPrefix(cd.URL, "sqlite://")
   230  	db = strings.TrimPrefix(db, "sqlite3://")
   231  
   232  	dbparts := strings.Split(db, "?")
   233  	cd.Database = dbparts[0]
   234  
   235  	if len(dbparts) != 2 {
   236  		return nil
   237  	}
   238  
   239  	q, err := url.ParseQuery(dbparts[1])
   240  	if err != nil {
   241  		return errors.Wrapf(err, "unable to parse sqlite query")
   242  	}
   243  
   244  	if cd.Options == nil { // prevent panic
   245  		cd.Options = make(map[string]string)
   246  	}
   247  	for k := range q {
   248  		cd.Options[k] = q.Get(k)
   249  	}
   250  
   251  	return nil
   252  }
   253  
   254  func finalizerSQLite(cd *ConnectionDetails) {
   255  	defs := map[string]string{
   256  		"_busy_timeout": "5000",
   257  	}
   258  	forced := map[string]string{
   259  		"_fk": "true",
   260  	}
   261  	if cd.Options == nil { // prevent panic
   262  		cd.Options = make(map[string]string)
   263  	}
   264  
   265  	for k, v := range defs {
   266  		cd.Options[k] = defaults.String(cd.Options[k], v)
   267  	}
   268  
   269  	for k, v := range forced {
   270  		// respect user specified options but print warning!
   271  		cd.Options[k] = defaults.String(cd.Options[k], v)
   272  		if cd.Options[k] != v { // when user-defined option exists
   273  			log(logging.Warn, "IMPORTANT! '%s: %s' option is required to work properly but your current setting is '%v: %v'.", k, v, k, cd.Options[k])
   274  			log(logging.Warn, "It is highly recommended to remove '%v: %v' option from your config!", k, cd.Options[k])
   275  		} // or override with `cd.Options[k] = v`?
   276  		if cd.URL != "" && !strings.Contains(cd.URL, k+"="+v) {
   277  			log(logging.Warn, "IMPORTANT! '%s=%s' option is required to work properly. Please add it to the database URL in the config!", k, v)
   278  		} // or fix user specified url?
   279  	}
   280  }