github.com/dolanor/pop@v4.13.0+incompatible/dialect_mysql.go (about)

     1  package pop
     2  
     3  import (
     4  	"bytes"
     5  	"database/sql"
     6  	"fmt"
     7  	"io"
     8  	"os/exec"
     9  	"strings"
    10  
    11  	// Load MySQL Go driver
    12  	_mysql "github.com/go-sql-driver/mysql"
    13  	"github.com/gobuffalo/fizz"
    14  	"github.com/gobuffalo/fizz/translators"
    15  	"github.com/gobuffalo/pop/columns"
    16  	"github.com/gobuffalo/pop/internal/defaults"
    17  	"github.com/gobuffalo/pop/internal/oncer"
    18  	"github.com/gobuffalo/pop/logging"
    19  	"github.com/pkg/errors"
    20  )
    21  
    22  const nameMySQL = "mysql"
    23  const hostMySQL = "localhost"
    24  const portMySQL = "3306"
    25  
    26  func init() {
    27  	AvailableDialects = append(AvailableDialects, nameMySQL)
    28  	urlParser[nameMySQL] = urlParserMySQL
    29  	finalizer[nameMySQL] = finalizerMySQL
    30  	newConnection[nameMySQL] = newMySQL
    31  }
    32  
    33  var _ dialect = &mysql{}
    34  
    35  type mysql struct {
    36  	commonDialect
    37  }
    38  
    39  func (m *mysql) Name() string {
    40  	return nameMySQL
    41  }
    42  
    43  func (mysql) Quote(key string) string {
    44  	return fmt.Sprintf("`%s`", key)
    45  }
    46  
    47  func (m *mysql) Details() *ConnectionDetails {
    48  	return m.ConnectionDetails
    49  }
    50  
    51  func (m *mysql) URL() string {
    52  	cd := m.ConnectionDetails
    53  	if cd.URL != "" {
    54  		return strings.TrimPrefix(cd.URL, "mysql://")
    55  	}
    56  
    57  	user := fmt.Sprintf("%s:%s@", cd.User, cd.Password)
    58  	user = strings.Replace(user, ":@", "@", 1)
    59  	if user == "@" || strings.HasPrefix(user, ":") {
    60  		user = ""
    61  	}
    62  
    63  	addr := fmt.Sprintf("(%s:%s)", cd.Host, cd.Port)
    64  	// in case of unix domain socket, tricky.
    65  	// it is better to check Host is not valid inet address or has '/'.
    66  	if cd.Port == "socket" {
    67  		addr = fmt.Sprintf("unix(%s)", cd.Host)
    68  	}
    69  
    70  	s := "%s%s/%s?%s"
    71  	return fmt.Sprintf(s, user, addr, cd.Database, cd.OptionsString(""))
    72  }
    73  
    74  func (m *mysql) urlWithoutDb() string {
    75  	cd := m.ConnectionDetails
    76  	return strings.Replace(m.URL(), "/"+cd.Database+"?", "/?", 1)
    77  }
    78  
    79  func (m *mysql) MigrationURL() string {
    80  	return m.URL()
    81  }
    82  
    83  func (m *mysql) Create(s store, model *Model, cols columns.Columns) error {
    84  	return errors.Wrap(genericCreate(s, model, cols, m), "mysql create")
    85  }
    86  
    87  func (m *mysql) Update(s store, model *Model, cols columns.Columns) error {
    88  	return errors.Wrap(genericUpdate(s, model, cols, m), "mysql update")
    89  }
    90  
    91  func (m *mysql) Destroy(s store, model *Model) error {
    92  	return errors.Wrap(genericDestroy(s, model, m), "mysql destroy")
    93  }
    94  
    95  func (m *mysql) SelectOne(s store, model *Model, query Query) error {
    96  	return errors.Wrap(genericSelectOne(s, model, query), "mysql select one")
    97  }
    98  
    99  func (m *mysql) SelectMany(s store, models *Model, query Query) error {
   100  	return errors.Wrap(genericSelectMany(s, models, query), "mysql select many")
   101  }
   102  
   103  // CreateDB creates a new database, from the given connection credentials
   104  func (m *mysql) CreateDB() error {
   105  	deets := m.ConnectionDetails
   106  	db, err := sql.Open(deets.Dialect, m.urlWithoutDb())
   107  	if err != nil {
   108  		return errors.Wrapf(err, "error creating MySQL database %s", deets.Database)
   109  	}
   110  	defer db.Close()
   111  	encoding := defaults.String(deets.Options["collation"], "utf8mb4_general_ci")
   112  	query := fmt.Sprintf("CREATE DATABASE `%s` DEFAULT COLLATE `%s`", deets.Database, encoding)
   113  	log(logging.SQL, query)
   114  
   115  	_, err = db.Exec(query)
   116  	if err != nil {
   117  		return errors.Wrapf(err, "error creating MySQL database %s", deets.Database)
   118  	}
   119  
   120  	log(logging.Info, "created database %s", deets.Database)
   121  	return nil
   122  }
   123  
   124  // DropDB drops an existing database, from the given connection credentials
   125  func (m *mysql) DropDB() error {
   126  	deets := m.ConnectionDetails
   127  	db, err := sql.Open(deets.Dialect, m.urlWithoutDb())
   128  	if err != nil {
   129  		return errors.Wrapf(err, "error dropping MySQL database %s", deets.Database)
   130  	}
   131  	defer db.Close()
   132  	query := fmt.Sprintf("DROP DATABASE `%s`", deets.Database)
   133  	log(logging.SQL, query)
   134  
   135  	_, err = db.Exec(query)
   136  	if err != nil {
   137  		return errors.Wrapf(err, "error dropping MySQL database %s", deets.Database)
   138  	}
   139  
   140  	log(logging.Info, "dropped database %s", deets.Database)
   141  	return nil
   142  }
   143  
   144  func (m *mysql) TranslateSQL(sql string) string {
   145  	return sql
   146  }
   147  
   148  func (m *mysql) FizzTranslator() fizz.Translator {
   149  	t := translators.NewMySQL(m.URL(), m.Details().Database)
   150  	return t
   151  }
   152  
   153  func (m *mysql) DumpSchema(w io.Writer) error {
   154  	deets := m.Details()
   155  	cmd := exec.Command("mysqldump", "-d", "-h", deets.Host, "-P", deets.Port, "-u", deets.User, fmt.Sprintf("--password=%s", deets.Password), deets.Database)
   156  	if deets.Port == "socket" {
   157  		cmd = exec.Command("mysqldump", "-d", "-S", deets.Host, "-u", deets.User, fmt.Sprintf("--password=%s", deets.Password), deets.Database)
   158  	}
   159  	return genericDumpSchema(deets, cmd, w)
   160  }
   161  
   162  // LoadSchema executes a schema sql file against the configured database.
   163  func (m *mysql) LoadSchema(r io.Reader) error {
   164  	return genericLoadSchema(m.ConnectionDetails, m.MigrationURL(), r)
   165  }
   166  
   167  // TruncateAll truncates all tables for the given connection.
   168  func (m *mysql) TruncateAll(tx *Connection) error {
   169  	var stmts []string
   170  	err := tx.RawQuery(mysqlTruncate, m.Details().Database).All(&stmts)
   171  	if err != nil {
   172  		return err
   173  	}
   174  	if len(stmts) == 0 {
   175  		return nil
   176  	}
   177  
   178  	var qb bytes.Buffer
   179  	// #49: Disable foreign keys before truncation
   180  	qb.WriteString("SET SESSION FOREIGN_KEY_CHECKS = 0; ")
   181  	qb.WriteString(strings.Join(stmts, " "))
   182  	// #49: Re-enable foreign keys after truncation
   183  	qb.WriteString(" SET SESSION FOREIGN_KEY_CHECKS = 1;")
   184  
   185  	return tx.RawQuery(qb.String()).Exec()
   186  }
   187  
   188  func newMySQL(deets *ConnectionDetails) (dialect, error) {
   189  	cd := &mysql{
   190  		commonDialect: commonDialect{ConnectionDetails: deets},
   191  	}
   192  	return cd, nil
   193  }
   194  
   195  func urlParserMySQL(cd *ConnectionDetails) error {
   196  	cfg, err := _mysql.ParseDSN(strings.TrimPrefix(cd.URL, "mysql://"))
   197  	if err != nil {
   198  		return errors.Wrapf(err, "the URL '%s' is not supported by MySQL driver", cd.URL)
   199  	}
   200  
   201  	cd.User = cfg.User
   202  	cd.Password = cfg.Passwd
   203  	cd.Database = cfg.DBName
   204  	if cd.Options == nil { // prevent panic
   205  		cd.Options = make(map[string]string)
   206  	}
   207  	// NOTE: use cfg.Params if want to fill options with full parameters
   208  	cd.Options["collation"] = cfg.Collation
   209  	if cfg.Net == "unix" {
   210  		cd.Port = "socket" // trick. see: `URL()`
   211  		cd.Host = cfg.Addr
   212  	} else {
   213  		tmp := strings.Split(cfg.Addr, ":")
   214  		cd.Host = tmp[0]
   215  		if len(tmp) > 1 {
   216  			cd.Port = tmp[1]
   217  		}
   218  	}
   219  
   220  	return nil
   221  }
   222  
   223  func finalizerMySQL(cd *ConnectionDetails) {
   224  	cd.Host = defaults.String(cd.Host, hostMySQL)
   225  	cd.Port = defaults.String(cd.Port, portMySQL)
   226  
   227  	defs := map[string]string{
   228  		"readTimeout": "3s",
   229  		"collation":   "utf8mb4_general_ci",
   230  	}
   231  	forced := map[string]string{
   232  		"parseTime":       "true",
   233  		"multiStatements": "true",
   234  	}
   235  
   236  	if cd.Options == nil { // prevent panic
   237  		cd.Options = make(map[string]string)
   238  	}
   239  
   240  	for k, v := range defs {
   241  		cd.Options[k] = defaults.String(cd.Options[k], v)
   242  	}
   243  
   244  	for k, v := range forced {
   245  		// respect user specified options but print warning!
   246  		cd.Options[k] = defaults.String(cd.Options[k], v)
   247  		if cd.Options[k] != v { // when user-defined option exists
   248  			log(logging.Warn, "IMPORTANT! '%s: %s' option is required to work properly but your current setting is '%v: %v'.", k, v, k, cd.Options[k])
   249  			log(logging.Warn, "It is highly recommended to remove '%v: %v' option from your config!", k, cd.Options[k])
   250  		} // or override with `cd.Options[k] = v`?
   251  		if cd.URL != "" && !strings.Contains(cd.URL, k+"="+v) {
   252  			log(logging.Warn, "IMPORTANT! '%s=%s' option is required to work properly. Please add it to the database URL in the config!", k, v)
   253  		} // or fix user specified url?
   254  	}
   255  
   256  	if cd.Encoding != "" {
   257  		//! DEPRECATED, 2018-11-06
   258  		// when user still uses `encoding:` in database.yml
   259  		oncer.Deprecate(0, "Encoding", "use options.collation")
   260  		cd.Options["collation"] = cd.Encoding
   261  	}
   262  }
   263  
   264  const mysqlTruncate = "SELECT concat('TRUNCATE TABLE `', TABLE_NAME, '`;') as stmt FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_type <> 'VIEW'"