github.com/rjgonzale/pop/v5@v5.1.3-dev/dialect_mysql.go (about)

     1  package pop
     2  
     3  import (
     4  	"bytes"
     5  	"database/sql"
     6  	"fmt"
     7  	"io"
     8  	"os/exec"
     9  	"strings"
    10  
    11  	// Load MySQL Go driver
    12  	_mysql "github.com/go-sql-driver/mysql"
    13  	"github.com/gobuffalo/fizz"
    14  	"github.com/gobuffalo/fizz/translators"
    15  	"github.com/gobuffalo/pop/v5/columns"
    16  	"github.com/gobuffalo/pop/v5/internal/defaults"
    17  	"github.com/gobuffalo/pop/v5/logging"
    18  	"github.com/pkg/errors"
    19  )
    20  
    21  const nameMySQL = "mysql"
    22  const hostMySQL = "localhost"
    23  const portMySQL = "3306"
    24  
    25  func init() {
    26  	AvailableDialects = append(AvailableDialects, nameMySQL)
    27  	urlParser[nameMySQL] = urlParserMySQL
    28  	finalizer[nameMySQL] = finalizerMySQL
    29  	newConnection[nameMySQL] = newMySQL
    30  }
    31  
    32  var _ dialect = &mysql{}
    33  
    34  type mysql struct {
    35  	commonDialect
    36  }
    37  
    38  func (m *mysql) Name() string {
    39  	return nameMySQL
    40  }
    41  
    42  func (m *mysql) DefaultDriver() string {
    43  	return nameMySQL
    44  }
    45  
    46  func (mysql) Quote(key string) string {
    47  	return fmt.Sprintf("`%s`", key)
    48  }
    49  
    50  func (m *mysql) Details() *ConnectionDetails {
    51  	return m.ConnectionDetails
    52  }
    53  
    54  func (m *mysql) URL() string {
    55  	cd := m.ConnectionDetails
    56  	if cd.URL != "" {
    57  		return strings.TrimPrefix(cd.URL, "mysql://")
    58  	}
    59  
    60  	user := fmt.Sprintf("%s:%s@", cd.User, cd.Password)
    61  	user = strings.Replace(user, ":@", "@", 1)
    62  	if user == "@" || strings.HasPrefix(user, ":") {
    63  		user = ""
    64  	}
    65  
    66  	addr := fmt.Sprintf("(%s:%s)", cd.Host, cd.Port)
    67  	// in case of unix domain socket, tricky.
    68  	// it is better to check Host is not valid inet address or has '/'.
    69  	if cd.Port == "socket" {
    70  		addr = fmt.Sprintf("unix(%s)", cd.Host)
    71  	}
    72  
    73  	s := "%s%s/%s?%s"
    74  	return fmt.Sprintf(s, user, addr, cd.Database, cd.OptionsString(""))
    75  }
    76  
    77  func (m *mysql) urlWithoutDb() string {
    78  	cd := m.ConnectionDetails
    79  	return strings.Replace(m.URL(), "/"+cd.Database+"?", "/?", 1)
    80  }
    81  
    82  func (m *mysql) MigrationURL() string {
    83  	return m.URL()
    84  }
    85  
    86  func (m *mysql) Create(s store, model *Model, cols columns.Columns) error {
    87  	return errors.Wrap(genericCreate(s, model, cols, m), "mysql create")
    88  }
    89  
    90  func (m *mysql) Update(s store, model *Model, cols columns.Columns) error {
    91  	return errors.Wrap(genericUpdate(s, model, cols, m), "mysql update")
    92  }
    93  
    94  func (m *mysql) Destroy(s store, model *Model) error {
    95  	return errors.Wrap(genericDestroy(s, model, m), "mysql destroy")
    96  }
    97  
    98  func (m *mysql) SelectOne(s store, model *Model, query Query) error {
    99  	return errors.Wrap(genericSelectOne(s, model, query), "mysql select one")
   100  }
   101  
   102  func (m *mysql) SelectMany(s store, models *Model, query Query) error {
   103  	return errors.Wrap(genericSelectMany(s, models, query), "mysql select many")
   104  }
   105  
   106  // CreateDB creates a new database, from the given connection credentials
   107  func (m *mysql) CreateDB() error {
   108  	deets := m.ConnectionDetails
   109  	db, err := sql.Open(deets.Dialect, m.urlWithoutDb())
   110  	if err != nil {
   111  		return errors.Wrapf(err, "error creating MySQL database %s", deets.Database)
   112  	}
   113  	defer db.Close()
   114  	charset := defaults.String(deets.Options["charset"], "utf8mb4")
   115  	encoding := defaults.String(deets.Options["collation"], "utf8mb4_general_ci")
   116  	query := fmt.Sprintf("CREATE DATABASE `%s` DEFAULT CHARSET `%s` DEFAULT COLLATE `%s`", deets.Database, charset, encoding)
   117  	log(logging.SQL, query)
   118  
   119  	_, err = db.Exec(query)
   120  	if err != nil {
   121  		return errors.Wrapf(err, "error creating MySQL database %s", deets.Database)
   122  	}
   123  
   124  	log(logging.Info, "created database %s", deets.Database)
   125  	return nil
   126  }
   127  
   128  // DropDB drops an existing database, from the given connection credentials
   129  func (m *mysql) DropDB() error {
   130  	deets := m.ConnectionDetails
   131  	db, err := sql.Open(deets.Dialect, m.urlWithoutDb())
   132  	if err != nil {
   133  		return errors.Wrapf(err, "error dropping MySQL database %s", deets.Database)
   134  	}
   135  	defer db.Close()
   136  	query := fmt.Sprintf("DROP DATABASE `%s`", deets.Database)
   137  	log(logging.SQL, query)
   138  
   139  	_, err = db.Exec(query)
   140  	if err != nil {
   141  		return errors.Wrapf(err, "error dropping MySQL database %s", deets.Database)
   142  	}
   143  
   144  	log(logging.Info, "dropped database %s", deets.Database)
   145  	return nil
   146  }
   147  
   148  func (m *mysql) TranslateSQL(sql string) string {
   149  	return sql
   150  }
   151  
   152  func (m *mysql) FizzTranslator() fizz.Translator {
   153  	t := translators.NewMySQL(m.URL(), m.Details().Database)
   154  	return t
   155  }
   156  
   157  func (m *mysql) DumpSchema(w io.Writer) error {
   158  	deets := m.Details()
   159  	cmd := exec.Command("mysqldump", "-d", "-h", deets.Host, "-P", deets.Port, "-u", deets.User, fmt.Sprintf("--password=%s", deets.Password), deets.Database)
   160  	if deets.Port == "socket" {
   161  		cmd = exec.Command("mysqldump", "-d", "-S", deets.Host, "-u", deets.User, fmt.Sprintf("--password=%s", deets.Password), deets.Database)
   162  	}
   163  	return genericDumpSchema(deets, cmd, w)
   164  }
   165  
   166  // LoadSchema executes a schema sql file against the configured database.
   167  func (m *mysql) LoadSchema(r io.Reader) error {
   168  	return genericLoadSchema(m.ConnectionDetails, m.MigrationURL(), r)
   169  }
   170  
   171  // TruncateAll truncates all tables for the given connection.
   172  func (m *mysql) TruncateAll(tx *Connection) error {
   173  	var stmts []string
   174  	err := tx.RawQuery(mysqlTruncate, m.Details().Database, tx.MigrationTableName()).All(&stmts)
   175  	if err != nil {
   176  		return err
   177  	}
   178  	if len(stmts) == 0 {
   179  		return nil
   180  	}
   181  
   182  	var qb bytes.Buffer
   183  	// #49: Disable foreign keys before truncation
   184  	qb.WriteString("SET SESSION FOREIGN_KEY_CHECKS = 0; ")
   185  	qb.WriteString(strings.Join(stmts, " "))
   186  	// #49: Re-enable foreign keys after truncation
   187  	qb.WriteString(" SET SESSION FOREIGN_KEY_CHECKS = 1;")
   188  
   189  	return tx.RawQuery(qb.String()).Exec()
   190  }
   191  
   192  func newMySQL(deets *ConnectionDetails) (dialect, error) {
   193  	cd := &mysql{
   194  		commonDialect: commonDialect{ConnectionDetails: deets},
   195  	}
   196  	return cd, nil
   197  }
   198  
   199  func urlParserMySQL(cd *ConnectionDetails) error {
   200  	cfg, err := _mysql.ParseDSN(strings.TrimPrefix(cd.URL, "mysql://"))
   201  	if err != nil {
   202  		return errors.Wrapf(err, "the URL '%s' is not supported by MySQL driver", cd.URL)
   203  	}
   204  
   205  	cd.User = cfg.User
   206  	cd.Password = cfg.Passwd
   207  	cd.Database = cfg.DBName
   208  	if cd.Options == nil { // prevent panic
   209  		cd.Options = make(map[string]string)
   210  	}
   211  	// NOTE: use cfg.Params if want to fill options with full parameters
   212  	cd.Options["collation"] = cfg.Collation
   213  	if cfg.Net == "unix" {
   214  		cd.Port = "socket" // trick. see: `URL()`
   215  		cd.Host = cfg.Addr
   216  	} else {
   217  		tmp := strings.Split(cfg.Addr, ":")
   218  		cd.Host = tmp[0]
   219  		if len(tmp) > 1 {
   220  			cd.Port = tmp[1]
   221  		}
   222  	}
   223  
   224  	return nil
   225  }
   226  
   227  func finalizerMySQL(cd *ConnectionDetails) {
   228  	cd.Host = defaults.String(cd.Host, hostMySQL)
   229  	cd.Port = defaults.String(cd.Port, portMySQL)
   230  
   231  	defs := map[string]string{
   232  		"readTimeout": "3s",
   233  		"collation":   "utf8mb4_general_ci",
   234  	}
   235  	forced := map[string]string{
   236  		"parseTime":       "true",
   237  		"multiStatements": "true",
   238  	}
   239  
   240  	if cd.Options == nil { // prevent panic
   241  		cd.Options = make(map[string]string)
   242  	}
   243  
   244  	for k, v := range defs {
   245  		cd.Options[k] = defaults.String(cd.Options[k], v)
   246  	}
   247  
   248  	for k, v := range forced {
   249  		// respect user specified options but print warning!
   250  		cd.Options[k] = defaults.String(cd.Options[k], v)
   251  		if cd.Options[k] != v { // when user-defined option exists
   252  			log(logging.Warn, "IMPORTANT! '%s: %s' option is required to work properly but your current setting is '%v: %v'.", k, v, k, cd.Options[k])
   253  			log(logging.Warn, "It is highly recommended to remove '%v: %v' option from your config!", k, cd.Options[k])
   254  		} // or override with `cd.Options[k] = v`?
   255  		if cd.URL != "" && !strings.Contains(cd.URL, k+"="+v) {
   256  			log(logging.Warn, "IMPORTANT! '%s=%s' option is required to work properly. Please add it to the database URL in the config!", k, v)
   257  		} // or fix user specified url?
   258  	}
   259  }
   260  
   261  const mysqlTruncate = "SELECT concat('TRUNCATE TABLE `', TABLE_NAME, '`;') as stmt FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name <> ? AND table_type <> 'VIEW'"