github.com/duskeagle/pop@v4.10.1-0.20190417200916-92f2b794aab5+incompatible/dialect_mysql.go (about)

     1  package pop
     2  
import (
	"bytes"
	"context"
	"database/sql"
	"fmt"
	"io"
	"net"
	"os/exec"
	"strings"

	// Load MySQL Go driver
	_mysql "github.com/go-sql-driver/mysql"
	"github.com/gobuffalo/fizz"
	"github.com/gobuffalo/fizz/translators"
	"github.com/gobuffalo/pop/columns"
	"github.com/gobuffalo/pop/logging"
	"github.com/markbates/going/defaults"
	"github.com/markbates/oncer"
	"github.com/pkg/errors"
)
    22  
    23  const nameMySQL = "mysql"
    24  const hostMySQL = "localhost"
    25  const portMySQL = "3306"
    26  
    27  func init() {
    28  	AvailableDialects = append(AvailableDialects, nameMySQL)
    29  	urlParser[nameMySQL] = urlParserMySQL
    30  	finalizer[nameMySQL] = finalizerMySQL
    31  	newConnection[nameMySQL] = newMySQL
    32  }
    33  
    34  var _ dialect = &mysql{}
    35  
    36  type mysql struct {
    37  	commonDialect
    38  }
    39  
    40  func (m *mysql) Name() string {
    41  	return nameMySQL
    42  }
    43  
    44  func (mysql) Quote(key string) string {
    45  	return fmt.Sprintf("`%s`", key)
    46  }
    47  
    48  func (m *mysql) Details() *ConnectionDetails {
    49  	return m.ConnectionDetails
    50  }
    51  
    52  func (m *mysql) URL() string {
    53  	cd := m.ConnectionDetails
    54  	if cd.URL != "" {
    55  		return strings.TrimPrefix(cd.URL, "mysql://")
    56  	}
    57  
    58  	user := fmt.Sprintf("%s:%s@", cd.User, cd.Password)
    59  	user = strings.Replace(user, ":@", "@", 1)
    60  	if user == "@" || strings.HasPrefix(user, ":") {
    61  		user = ""
    62  	}
    63  
    64  	addr := fmt.Sprintf("(%s:%s)", cd.Host, cd.Port)
    65  	// in case of unix domain socket, tricky.
    66  	// it is better to check Host is not valid inet address or has '/'.
    67  	if cd.Port == "socket" {
    68  		addr = fmt.Sprintf("unix(%s)", cd.Host)
    69  	}
    70  
    71  	s := "%s%s/%s?%s"
    72  	return fmt.Sprintf(s, user, addr, cd.Database, cd.OptionsString(""))
    73  }
    74  
    75  func (m *mysql) urlWithoutDb() string {
    76  	cd := m.ConnectionDetails
    77  	return strings.Replace(m.URL(), "/"+cd.Database+"?", "/?", 1)
    78  }
    79  
    80  func (m *mysql) MigrationURL() string {
    81  	return m.URL()
    82  }
    83  
    84  func (m *mysql) Create(s store, model *Model, cols columns.Columns) error {
    85  	return errors.Wrap(genericCreate(s, model, cols), "mysql create")
    86  }
    87  
    88  func (m *mysql) Update(s store, model *Model, cols columns.Columns) error {
    89  	return errors.Wrap(genericUpdate(s, model, cols), "mysql update")
    90  }
    91  
    92  func (m *mysql) Destroy(s store, model *Model) error {
    93  	return errors.Wrap(genericDestroy(s, model), "mysql destroy")
    94  }
    95  
    96  func (m *mysql) SelectOne(s store, model *Model, query Query) error {
    97  	return errors.Wrap(genericSelectOne(s, model, query), "mysql select one")
    98  }
    99  
   100  func (m *mysql) SelectMany(s store, models *Model, query Query) error {
   101  	return errors.Wrap(genericSelectMany(s, models, query), "mysql select many")
   102  }
   103  
   104  // CreateDB creates a new database, from the given connection credentials
   105  func (m *mysql) CreateDB() error {
   106  	deets := m.ConnectionDetails
   107  	db, err := sql.Open(deets.Dialect, m.urlWithoutDb())
   108  	if err != nil {
   109  		return errors.Wrapf(err, "error creating MySQL database %s", deets.Database)
   110  	}
   111  	defer db.Close()
   112  	encoding := defaults.String(deets.Options["collation"], "utf8mb4_general_ci")
   113  	query := fmt.Sprintf("CREATE DATABASE `%s` DEFAULT COLLATE `%s`", deets.Database, encoding)
   114  	log(logging.SQL, query)
   115  
   116  	_, err = db.Exec(query)
   117  	if err != nil {
   118  		return errors.Wrapf(err, "error creating MySQL database %s", deets.Database)
   119  	}
   120  
   121  	log(logging.Info, "created database %s", deets.Database)
   122  	return nil
   123  }
   124  
   125  // DropDB drops an existing database, from the given connection credentials
   126  func (m *mysql) DropDB() error {
   127  	deets := m.ConnectionDetails
   128  	db, err := sql.Open(deets.Dialect, m.urlWithoutDb())
   129  	if err != nil {
   130  		return errors.Wrapf(err, "error dropping MySQL database %s", deets.Database)
   131  	}
   132  	defer db.Close()
   133  	query := fmt.Sprintf("DROP DATABASE `%s`", deets.Database)
   134  	log(logging.SQL, query)
   135  
   136  	_, err = db.Exec(query)
   137  	if err != nil {
   138  		return errors.Wrapf(err, "error dropping MySQL database %s", deets.Database)
   139  	}
   140  
   141  	log(logging.Info, "dropped database %s", deets.Database)
   142  	return nil
   143  }
   144  
   145  func (m *mysql) TranslateSQL(sql string) string {
   146  	return sql
   147  }
   148  
   149  func (m *mysql) FizzTranslator() fizz.Translator {
   150  	t := translators.NewMySQL(m.URL(), m.Details().Database)
   151  	return t
   152  }
   153  
   154  func (m *mysql) DumpSchema(w io.Writer) error {
   155  	deets := m.Details()
   156  	cmd := exec.Command("mysqldump", "-d", "-h", deets.Host, "-P", deets.Port, "-u", deets.User, fmt.Sprintf("--password=%s", deets.Password), deets.Database)
   157  	if deets.Port == "socket" {
   158  		cmd = exec.Command("mysqldump", "-d", "-S", deets.Host, "-u", deets.User, fmt.Sprintf("--password=%s", deets.Password), deets.Database)
   159  	}
   160  	return genericDumpSchema(deets, cmd, w)
   161  }
   162  
   163  // LoadSchema executes a schema sql file against the configured database.
   164  func (m *mysql) LoadSchema(r io.Reader) error {
   165  	return genericLoadSchema(m.ConnectionDetails, m.MigrationURL(), r)
   166  }
   167  
   168  // TruncateAll truncates all tables for the given connection.
   169  func (m *mysql) TruncateAll(tx *Connection) error {
   170  	var stmts []string
   171  	err := tx.RawQuery(mysqlTruncate, m.Details().Database).All(context.TODO(), &stmts)
   172  	if err != nil {
   173  		return err
   174  	}
   175  	if len(stmts) == 0 {
   176  		return nil
   177  	}
   178  
   179  	var qb bytes.Buffer
   180  	// #49: Disable foreign keys before truncation
   181  	qb.WriteString("SET SESSION FOREIGN_KEY_CHECKS = 0; ")
   182  	qb.WriteString(strings.Join(stmts, " "))
   183  	// #49: Re-enable foreign keys after truncation
   184  	qb.WriteString(" SET SESSION FOREIGN_KEY_CHECKS = 1;")
   185  
   186  	return tx.RawQuery(qb.String()).Exec()
   187  }
   188  
   189  func newMySQL(deets *ConnectionDetails) (dialect, error) {
   190  	cd := &mysql{
   191  		commonDialect: commonDialect{ConnectionDetails: deets},
   192  	}
   193  	return cd, nil
   194  }
   195  
   196  func urlParserMySQL(cd *ConnectionDetails) error {
   197  	cfg, err := _mysql.ParseDSN(strings.TrimPrefix(cd.URL, "mysql://"))
   198  	if err != nil {
   199  		return errors.Wrapf(err, "the URL '%s' is not supported by MySQL driver", cd.URL)
   200  	}
   201  
   202  	cd.User = cfg.User
   203  	cd.Password = cfg.Passwd
   204  	cd.Database = cfg.DBName
   205  	if cd.Options == nil { // prevent panic
   206  		cd.Options = make(map[string]string)
   207  	}
   208  	// NOTE: use cfg.Params if want to fill options with full parameters
   209  	cd.Options["collation"] = cfg.Collation
   210  	if cfg.Net == "unix" {
   211  		cd.Port = "socket" // trick. see: `URL()`
   212  		cd.Host = cfg.Addr
   213  	} else {
   214  		tmp := strings.Split(cfg.Addr, ":")
   215  		cd.Host = tmp[0]
   216  		if len(tmp) > 1 {
   217  			cd.Port = tmp[1]
   218  		}
   219  	}
   220  
   221  	return nil
   222  }
   223  
   224  func finalizerMySQL(cd *ConnectionDetails) {
   225  	cd.Host = defaults.String(cd.Host, hostMySQL)
   226  	cd.Port = defaults.String(cd.Port, portMySQL)
   227  
   228  	defs := map[string]string{
   229  		"readTimeout": "3s",
   230  		"collation":   "utf8mb4_general_ci",
   231  	}
   232  	forced := map[string]string{
   233  		"parseTime":       "true",
   234  		"multiStatements": "true",
   235  	}
   236  
   237  	if cd.Options == nil { // prevent panic
   238  		cd.Options = make(map[string]string)
   239  	}
   240  
   241  	for k, v := range defs {
   242  		cd.Options[k] = defaults.String(cd.Options[k], v)
   243  	}
   244  
   245  	for k, v := range forced {
   246  		// respect user specified options but print warning!
   247  		cd.Options[k] = defaults.String(cd.Options[k], v)
   248  		if cd.Options[k] != v { // when user-defined option exists
   249  			log(logging.Warn, "IMPORTANT! '%s: %s' option is required to work properly but your current setting is '%v: %v'.", k, v, k, cd.Options[k])
   250  			log(logging.Warn, "It is highly recommended to remove '%v: %v' option from your config!", k, cd.Options[k])
   251  		} // or override with `cd.Options[k] = v`?
   252  		if cd.URL != "" && !strings.Contains(cd.URL, k+"="+v) {
   253  			log(logging.Warn, "IMPORTANT! '%s=%s' option is required to work properly. Please add it to the database URL in the config!", k, v)
   254  		} // or fix user specified url?
   255  	}
   256  
   257  	if cd.Encoding != "" {
   258  		//! DEPRECATED, 2018-11-06
   259  		// when user still uses `encoding:` in database.yml
   260  		oncer.Deprecate(0, "Encoding", "use options.collation")
   261  		cd.Options["collation"] = cd.Encoding
   262  	}
   263  }
   264  
   265  const mysqlTruncate = "SELECT concat('TRUNCATE TABLE `', TABLE_NAME, '`;') as stmt FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_type <> 'VIEW'"