github.com/dkishere/pop/v6@v6.103.1/dialect_mysql.go (about)

     1  package pop
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"io"
     7  	"os/exec"
     8  	"strings"
     9  
    10  	"github.com/dkishere/pop/v6/columns"
    11  	"github.com/dkishere/pop/v6/internal/defaults"
    12  	"github.com/dkishere/pop/v6/logging"
    13  	_mysql "github.com/go-sql-driver/mysql" // Load MySQL Go driver
    14  	"github.com/gobuffalo/fizz"
    15  	"github.com/gobuffalo/fizz/translators"
    16  )
    17  
    18  const nameMySQL = "mysql"
    19  const hostMySQL = "localhost"
    20  const portMySQL = "3306"
    21  
    22  func init() {
    23  	AvailableDialects = append(AvailableDialects, nameMySQL)
    24  	urlParser[nameMySQL] = urlParserMySQL
    25  	finalizer[nameMySQL] = finalizerMySQL
    26  	newConnection[nameMySQL] = newMySQL
    27  }
    28  
    29  var _ dialect = &mysql{}
    30  
    31  type mysql struct {
    32  	commonDialect
    33  }
    34  
    35  func (m *mysql) Name() string {
    36  	return nameMySQL
    37  }
    38  
    39  func (m *mysql) DefaultDriver() string {
    40  	return nameMySQL
    41  }
    42  
    43  func (mysql) Quote(key string) string {
    44  	return fmt.Sprintf("`%s`", key)
    45  }
    46  
    47  func (m *mysql) Details() *ConnectionDetails {
    48  	return m.ConnectionDetails
    49  }
    50  
    51  func (m *mysql) URL() string {
    52  	cd := m.ConnectionDetails
    53  	if cd.URL != "" {
    54  		return strings.TrimPrefix(cd.URL, "mysql://")
    55  	}
    56  
    57  	user := fmt.Sprintf("%s:%s@", cd.User, cd.Password)
    58  	user = strings.Replace(user, ":@", "@", 1)
    59  	if user == "@" || strings.HasPrefix(user, ":") {
    60  		user = ""
    61  	}
    62  
    63  	addr := fmt.Sprintf("(%s:%s)", cd.Host, cd.Port)
    64  	// in case of unix domain socket, tricky.
    65  	// it is better to check Host is not valid inet address or has '/'.
    66  	if cd.Port == "socket" {
    67  		addr = fmt.Sprintf("unix(%s)", cd.Host)
    68  	}
    69  
    70  	s := "%s%s/%s?%s"
    71  	return fmt.Sprintf(s, user, addr, cd.Database, cd.OptionsString(""))
    72  }
    73  
    74  func (m *mysql) urlWithoutDb() string {
    75  	cd := m.ConnectionDetails
    76  	return strings.Replace(m.URL(), "/"+cd.Database+"?", "/?", 1)
    77  }
    78  
    79  func (m *mysql) MigrationURL() string {
    80  	return m.URL()
    81  }
    82  
    83  func (m *mysql) Create(s store, model *Model, cols columns.Columns) error {
    84  	if err := genericCreate(s, model, cols, m); err != nil {
    85  		return fmt.Errorf("mysql create: %w", err)
    86  	}
    87  	return nil
    88  }
    89  
    90  func (m *mysql) Update(s store, model *Model, cols columns.Columns) error {
    91  	if err := genericUpdate(s, model, cols, m); err != nil {
    92  		return fmt.Errorf("mysql update: %w", err)
    93  	}
    94  	return nil
    95  }
    96  
    97  func (m *mysql) Destroy(s store, model *Model) error {
    98  	stmt := fmt.Sprintf("DELETE FROM %s  WHERE %s = ?", m.Quote(model.TableName()), model.IDField())
    99  	_, err := genericExec(s, stmt, model.ID())
   100  	if err != nil {
   101  		return fmt.Errorf("mysql destroy: %w", err)
   102  	}
   103  	return nil
   104  }
   105  
   106  func (m *mysql) Delete(s store, model *Model, query Query) error {
   107  	sb := query.toSQLBuilder(model)
   108  
   109  	sql := fmt.Sprintf("DELETE FROM %s", m.Quote(model.TableName()))
   110  	sql = sb.buildWhereClauses(sql)
   111  
   112  	_, err := genericExec(s, sql, sb.Args()...)
   113  	if err != nil {
   114  		return fmt.Errorf("mysql delete: %w", err)
   115  	}
   116  	return nil
   117  }
   118  
   119  func (m *mysql) SelectOne(s store, model *Model, query Query) error {
   120  	if err := genericSelectOne(s, model, query); err != nil {
   121  		return fmt.Errorf("mysql select one: %w", err)
   122  	}
   123  	return nil
   124  }
   125  
   126  func (m *mysql) SelectMany(s store, models *Model, query Query) error {
   127  	if err := genericSelectMany(s, models, query); err != nil {
   128  		return fmt.Errorf("mysql select many: %w", err)
   129  	}
   130  	return nil
   131  }
   132  
   133  // CreateDB creates a new database, from the given connection credentials
   134  func (m *mysql) CreateDB() error {
   135  	deets := m.ConnectionDetails
   136  	db, err := openPotentiallyInstrumentedConnection(m, m.urlWithoutDb())
   137  	if err != nil {
   138  		return fmt.Errorf("error creating MySQL database %s: %w", deets.Database, err)
   139  	}
   140  	defer db.Close()
   141  	charset := defaults.String(deets.Options["charset"], "utf8mb4")
   142  	encoding := defaults.String(deets.Options["collation"], "utf8mb4_general_ci")
   143  	query := fmt.Sprintf("CREATE DATABASE `%s` DEFAULT CHARSET `%s` DEFAULT COLLATE `%s`", deets.Database, charset, encoding)
   144  	log(logging.SQL, query)
   145  
   146  	_, err = db.Exec(query)
   147  	if err != nil {
   148  		return fmt.Errorf("error creating MySQL database %s: %w", deets.Database, err)
   149  	}
   150  
   151  	log(logging.Info, "created database %s", deets.Database)
   152  	return nil
   153  }
   154  
   155  // DropDB drops an existing database, from the given connection credentials
   156  func (m *mysql) DropDB() error {
   157  	deets := m.ConnectionDetails
   158  	db, err := openPotentiallyInstrumentedConnection(m, m.urlWithoutDb())
   159  	if err != nil {
   160  		return fmt.Errorf("error dropping MySQL database %s: %w", deets.Database, err)
   161  	}
   162  	defer db.Close()
   163  	query := fmt.Sprintf("DROP DATABASE `%s`", deets.Database)
   164  	log(logging.SQL, query)
   165  
   166  	_, err = db.Exec(query)
   167  	if err != nil {
   168  		return fmt.Errorf("error dropping MySQL database %s: %w", deets.Database, err)
   169  	}
   170  
   171  	log(logging.Info, "dropped database %s", deets.Database)
   172  	return nil
   173  }
   174  
   175  func (m *mysql) TranslateSQL(sql string) string {
   176  	return sql
   177  }
   178  
   179  func (m *mysql) FizzTranslator() fizz.Translator {
   180  	t := translators.NewMySQL(m.URL(), m.Details().Database)
   181  	return t
   182  }
   183  
   184  func (m *mysql) DumpSchema(w io.Writer) error {
   185  	deets := m.Details()
   186  	// Github CI is currently using mysql:5.7 but the mysqldump version doesn't seem to match
   187  	cmd := exec.Command("mysqldump", "--column-statistics=0", "-d", "-h", deets.Host, "-P", deets.Port, "-u", deets.User, fmt.Sprintf("--password=%s", deets.Password), deets.Database)
   188  	if deets.Port == "socket" {
   189  		cmd = exec.Command("mysqldump", "--column-statistics=0", "-d", "-S", deets.Host, "-u", deets.User, fmt.Sprintf("--password=%s", deets.Password), deets.Database)
   190  	}
   191  	return genericDumpSchema(deets, cmd, w)
   192  }
   193  
   194  // LoadSchema executes a schema sql file against the configured database.
   195  func (m *mysql) LoadSchema(r io.Reader) error {
   196  	return genericLoadSchema(m, r)
   197  }
   198  
   199  // TruncateAll truncates all tables for the given connection.
   200  func (m *mysql) TruncateAll(tx *Connection) error {
   201  	var stmts []string
   202  	err := tx.RawQuery(mysqlTruncate, m.Details().Database, tx.MigrationTableName()).All(&stmts)
   203  	if err != nil {
   204  		return err
   205  	}
   206  	if len(stmts) == 0 {
   207  		return nil
   208  	}
   209  
   210  	var qb bytes.Buffer
   211  	// #49: Disable foreign keys before truncation
   212  	qb.WriteString("SET SESSION FOREIGN_KEY_CHECKS = 0; ")
   213  	qb.WriteString(strings.Join(stmts, " "))
   214  	// #49: Re-enable foreign keys after truncation
   215  	qb.WriteString(" SET SESSION FOREIGN_KEY_CHECKS = 1;")
   216  
   217  	return tx.RawQuery(qb.String()).Exec()
   218  }
   219  
   220  func newMySQL(deets *ConnectionDetails) (dialect, error) {
   221  	cd := &mysql{
   222  		commonDialect: commonDialect{ConnectionDetails: deets},
   223  	}
   224  	return cd, nil
   225  }
   226  
   227  func urlParserMySQL(cd *ConnectionDetails) error {
   228  	cfg, err := _mysql.ParseDSN(strings.TrimPrefix(cd.URL, "mysql://"))
   229  	if err != nil {
   230  		return fmt.Errorf("the URL '%s' is not supported by MySQL driver: %w", cd.URL, err)
   231  	}
   232  
   233  	cd.User = cfg.User
   234  	cd.Password = cfg.Passwd
   235  	cd.Database = cfg.DBName
   236  	if cd.Options == nil { // prevent panic
   237  		cd.Options = make(map[string]string)
   238  	}
   239  	// NOTE: use cfg.Params if want to fill options with full parameters
   240  	cd.Options["collation"] = cfg.Collation
   241  	if cfg.Net == "unix" {
   242  		cd.Port = "socket" // trick. see: `URL()`
   243  		cd.Host = cfg.Addr
   244  	} else {
   245  		tmp := strings.Split(cfg.Addr, ":")
   246  		cd.Host = tmp[0]
   247  		if len(tmp) > 1 {
   248  			cd.Port = tmp[1]
   249  		}
   250  	}
   251  
   252  	return nil
   253  }
   254  
   255  func finalizerMySQL(cd *ConnectionDetails) {
   256  	cd.Host = defaults.String(cd.Host, hostMySQL)
   257  	cd.Port = defaults.String(cd.Port, portMySQL)
   258  
   259  	defs := map[string]string{
   260  		"readTimeout": "3s",
   261  		"collation":   "utf8mb4_general_ci",
   262  	}
   263  	forced := map[string]string{
   264  		"parseTime":       "true",
   265  		"multiStatements": "true",
   266  	}
   267  
   268  	if cd.Options == nil { // prevent panic
   269  		cd.Options = make(map[string]string)
   270  	}
   271  
   272  	for k, v := range defs {
   273  		cd.Options[k] = defaults.String(cd.Options[k], v)
   274  	}
   275  
   276  	for k, v := range forced {
   277  		// respect user specified options but print warning!
   278  		cd.Options[k] = defaults.String(cd.Options[k], v)
   279  		if cd.Options[k] != v { // when user-defined option exists
   280  			log(logging.Warn, "IMPORTANT! '%s: %s' option is required to work properly but your current setting is '%v: %v'.", k, v, k, cd.Options[k])
   281  			log(logging.Warn, "It is highly recommended to remove '%v: %v' option from your config!", k, cd.Options[k])
   282  		} // or override with `cd.Options[k] = v`?
   283  		if cd.URL != "" && !strings.Contains(cd.URL, k+"="+v) {
   284  			log(logging.Warn, "IMPORTANT! '%s=%s' option is required to work properly. Please add it to the database URL in the config!", k, v)
   285  		} // or fix user specified url?
   286  	}
   287  }
   288  
   289  const mysqlTruncate = "SELECT concat('TRUNCATE TABLE `', TABLE_NAME, '`;') as stmt FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name <> ? AND table_type <> 'VIEW'"