github.com/Accefy/pop@v0.0.0-20230428174248-e9f677eab5b9/dialect_mysql.go (about)

     1  package pop
     2  
import (
	"bytes"
	"fmt"
	"io"
	"net"
	"os/exec"
	"regexp"
	"strings"

	"github.com/Accefy/pop/internal/defaults"
	_mysql "github.com/go-sql-driver/mysql" // Load MySQL Go driver
	"github.com/gobuffalo/fizz"
	"github.com/gobuffalo/fizz/translators"
	"github.com/gobuffalo/pop/v6/columns"
	"github.com/gobuffalo/pop/v6/logging"
	"github.com/jmoiron/sqlx"
)
    19  
    20  const nameMySQL = "mysql"
    21  const hostMySQL = "localhost"
    22  const portMySQL = "3306"
    23  
    24  func init() {
    25  	AvailableDialects = append(AvailableDialects, nameMySQL)
    26  	urlParser[nameMySQL] = urlParserMySQL
    27  	finalizer[nameMySQL] = finalizerMySQL
    28  	newConnection[nameMySQL] = newMySQL
    29  }
    30  
    31  var _ dialect = &mysql{}
    32  
    33  type mysql struct {
    34  	commonDialect
    35  }
    36  
// Name returns the name of this dialect, the constant nameMySQL ("mysql").
func (m *mysql) Name() string {
	return nameMySQL
}
    40  
// DefaultDriver returns the default driver name for this dialect. It is the
// same constant that Name returns (nameMySQL).
func (m *mysql) DefaultDriver() string {
	return nameMySQL
}
    44  
// Quote wraps key in backticks so it can be used as a MySQL identifier.
// The whole key is quoted as one unit.
// NOTE(review): key is inserted verbatim; a backtick inside key is not
// escaped, so callers must pass trusted identifiers.
func (mysql) Quote(key string) string {
	return fmt.Sprintf("`%s`", key)
}
    48  
// Details returns the ConnectionDetails this dialect was constructed with.
func (m *mysql) Details() *ConnectionDetails {
	return m.ConnectionDetails
}
    52  
    53  func (m *mysql) URL() string {
    54  	cd := m.ConnectionDetails
    55  	if cd.URL != "" {
    56  		return strings.TrimPrefix(cd.URL, "mysql://")
    57  	}
    58  
    59  	user := fmt.Sprintf("%s:%s@", cd.User, cd.Password)
    60  	user = strings.Replace(user, ":@", "@", 1)
    61  	if user == "@" || strings.HasPrefix(user, ":") {
    62  		user = ""
    63  	}
    64  
    65  	addr := fmt.Sprintf("(%s:%s)", cd.Host, cd.Port)
    66  	// in case of unix domain socket, tricky.
    67  	// it is better to check Host is not valid inet address or has '/'.
    68  	if cd.Port == "socket" {
    69  		addr = fmt.Sprintf("unix(%s)", cd.Host)
    70  	}
    71  
    72  	s := "%s%s/%s?%s"
    73  	return fmt.Sprintf(s, user, addr, cd.Database, cd.OptionsString(""))
    74  }
    75  
    76  func (m *mysql) urlWithoutDb() string {
    77  	cd := m.ConnectionDetails
    78  	return strings.Replace(m.URL(), "/"+cd.Database+"?", "/?", 1)
    79  }
    80  
// MigrationURL returns the DSN to use for migrations; for MySQL it is the
// same value URL returns.
func (m *mysql) MigrationURL() string {
	return m.URL()
}
    84  
    85  func (m *mysql) Create(c *Connection, model *Model, cols columns.Columns) error {
    86  	if err := genericCreate(c, model, cols, m); err != nil {
    87  		return fmt.Errorf("mysql create: %w", err)
    88  	}
    89  	return nil
    90  }
    91  
    92  func (m *mysql) Update(c *Connection, model *Model, cols columns.Columns) error {
    93  	if err := genericUpdate(c, model, cols, m); err != nil {
    94  		return fmt.Errorf("mysql update: %w", err)
    95  	}
    96  	return nil
    97  }
    98  
    99  func (m *mysql) UpdateQuery(c *Connection, model *Model, cols columns.Columns, query Query) (int64, error) {
   100  	if n, err := genericUpdateQuery(c, model, cols, m, query, sqlx.QUESTION); err != nil {
   101  		return n, fmt.Errorf("mysql update query: %w", err)
   102  	} else {
   103  		return n, nil
   104  	}
   105  }
   106  
   107  func (m *mysql) Destroy(c *Connection, model *Model) error {
   108  	stmt := fmt.Sprintf("DELETE FROM %s  WHERE %s = ?", m.Quote(model.TableName()), model.IDField())
   109  	_, err := genericExec(c, stmt, model.ID())
   110  	if err != nil {
   111  		return fmt.Errorf("mysql destroy: %w", err)
   112  	}
   113  	return nil
   114  }
   115  
   116  var asRegex = regexp.MustCompile(`\sAS\s\S+`) // exactly " AS non-spaces"
   117  
   118  func (m *mysql) Delete(c *Connection, model *Model, query Query) error {
   119  	sqlQuery, args := query.ToSQL(model)
   120  	// * MySQL does not support table alias for DELETE syntax until 8.0.
   121  	// * Do not generate SQL manually if they may have `WHERE IN`.
   122  	// * Spaces are intentionally added to make it easy to see on the log.
   123  	sqlQuery = asRegex.ReplaceAllString(sqlQuery, "  ")
   124  
   125  	_, err := genericExec(c, sqlQuery, args...)
   126  	return err
   127  }
   128  
   129  func (m *mysql) SelectOne(c *Connection, model *Model, query Query) error {
   130  	if err := genericSelectOne(c, model, query); err != nil {
   131  		return fmt.Errorf("mysql select one: %w", err)
   132  	}
   133  	return nil
   134  }
   135  
   136  func (m *mysql) SelectMany(c *Connection, models *Model, query Query) error {
   137  	if err := genericSelectMany(c, models, query); err != nil {
   138  		return fmt.Errorf("mysql select many: %w", err)
   139  	}
   140  	return nil
   141  }
   142  
   143  // CreateDB creates a new database, from the given connection credentials
   144  func (m *mysql) CreateDB() error {
   145  	deets := m.ConnectionDetails
   146  	db, err := openPotentiallyInstrumentedConnection(m, m.urlWithoutDb())
   147  	if err != nil {
   148  		return fmt.Errorf("error creating MySQL database %s: %w", deets.Database, err)
   149  	}
   150  	defer db.Close()
   151  	charset := defaults.String(deets.option("charset"), "utf8mb4")
   152  	encoding := defaults.String(deets.option("collation"), "utf8mb4_general_ci")
   153  	query := fmt.Sprintf("CREATE DATABASE `%s` DEFAULT CHARSET `%s` DEFAULT COLLATE `%s`", deets.Database, charset, encoding)
   154  	log(logging.SQL, query)
   155  
   156  	_, err = db.Exec(query)
   157  	if err != nil {
   158  		return fmt.Errorf("error creating MySQL database %s: %w", deets.Database, err)
   159  	}
   160  
   161  	log(logging.Info, "created database %s", deets.Database)
   162  	return nil
   163  }
   164  
   165  // DropDB drops an existing database, from the given connection credentials
   166  func (m *mysql) DropDB() error {
   167  	deets := m.ConnectionDetails
   168  	db, err := openPotentiallyInstrumentedConnection(m, m.urlWithoutDb())
   169  	if err != nil {
   170  		return fmt.Errorf("error dropping MySQL database %s: %w", deets.Database, err)
   171  	}
   172  	defer db.Close()
   173  	query := fmt.Sprintf("DROP DATABASE `%s`", deets.Database)
   174  	log(logging.SQL, query)
   175  
   176  	_, err = db.Exec(query)
   177  	if err != nil {
   178  		return fmt.Errorf("error dropping MySQL database %s: %w", deets.Database, err)
   179  	}
   180  
   181  	log(logging.Info, "dropped database %s", deets.Database)
   182  	return nil
   183  }
   184  
// TranslateSQL returns sql unchanged; this dialect applies no rewriting.
func (m *mysql) TranslateSQL(sql string) string {
	return sql
}
   188  
   189  func (m *mysql) FizzTranslator() fizz.Translator {
   190  	t := translators.NewMySQL(m.URL(), m.Details().Database)
   191  	return t
   192  }
   193  
   194  func (m *mysql) DumpSchema(w io.Writer) error {
   195  	deets := m.Details()
   196  	cmd := exec.Command("mysqldump", "-d", "-h", deets.Host, "-P", deets.Port, "-u", deets.User, fmt.Sprintf("--password=%s", deets.Password), deets.Database)
   197  	if deets.Port == "socket" {
   198  		cmd = exec.Command("mysqldump", "-d", "-S", deets.Host, "-u", deets.User, fmt.Sprintf("--password=%s", deets.Password), deets.Database)
   199  	}
   200  	return genericDumpSchema(deets, cmd, w)
   201  }
   202  
// LoadSchema executes a schema sql file against the configured database.
// The work is delegated to genericLoadSchema, which receives this dialect
// and the reader r.
func (m *mysql) LoadSchema(r io.Reader) error {
	return genericLoadSchema(m, r)
}
   207  
   208  // TruncateAll truncates all tables for the given connection.
   209  func (m *mysql) TruncateAll(tx *Connection) error {
   210  	var stmts []string
   211  	err := tx.RawQuery(mysqlTruncate, m.Details().Database, tx.MigrationTableName()).All(&stmts)
   212  	if err != nil {
   213  		return err
   214  	}
   215  	if len(stmts) == 0 {
   216  		return nil
   217  	}
   218  
   219  	var qb bytes.Buffer
   220  	// #49: Disable foreign keys before truncation
   221  	qb.WriteString("SET SESSION FOREIGN_KEY_CHECKS = 0; ")
   222  	qb.WriteString(strings.Join(stmts, " "))
   223  	// #49: Re-enable foreign keys after truncation
   224  	qb.WriteString(" SET SESSION FOREIGN_KEY_CHECKS = 1;")
   225  
   226  	return tx.RawQuery(qb.String()).Exec()
   227  }
   228  
   229  func newMySQL(deets *ConnectionDetails) (dialect, error) {
   230  	cd := &mysql{
   231  		commonDialect: commonDialect{ConnectionDetails: deets},
   232  	}
   233  	return cd, nil
   234  }
   235  
   236  func urlParserMySQL(cd *ConnectionDetails) error {
   237  	cfg, err := _mysql.ParseDSN(strings.TrimPrefix(cd.URL, "mysql://"))
   238  	if err != nil {
   239  		return fmt.Errorf("the URL '%s' is not supported by MySQL driver: %w", cd.URL, err)
   240  	}
   241  
   242  	cd.User = cfg.User
   243  	cd.Password = cfg.Passwd
   244  	cd.Database = cfg.DBName
   245  
   246  	// NOTE: use cfg.Params if want to fill options with full parameters
   247  	cd.setOption("collation", cfg.Collation)
   248  
   249  	if cfg.Net == "unix" {
   250  		cd.Port = "socket" // trick. see: `URL()`
   251  		cd.Host = cfg.Addr
   252  	} else {
   253  		tmp := strings.Split(cfg.Addr, ":")
   254  		cd.Host = tmp[0]
   255  		if len(tmp) > 1 {
   256  			cd.Port = tmp[1]
   257  		}
   258  	}
   259  
   260  	return nil
   261  }
   262  
   263  func finalizerMySQL(cd *ConnectionDetails) {
   264  	cd.Host = defaults.String(cd.Host, hostMySQL)
   265  	cd.Port = defaults.String(cd.Port, portMySQL)
   266  
   267  	defs := map[string]string{
   268  		"readTimeout": "3s",
   269  		"collation":   "utf8mb4_general_ci",
   270  	}
   271  	forced := map[string]string{
   272  		"parseTime":       "true",
   273  		"multiStatements": "true",
   274  	}
   275  
   276  	for k, def := range defs {
   277  		cd.setOptionWithDefault(k, cd.option(k), def)
   278  	}
   279  
   280  	for k, v := range forced {
   281  		// respect user specified options but print warning!
   282  		cd.setOptionWithDefault(k, cd.option(k), v)
   283  		if cd.option(k) != v { // when user-defined option exists
   284  			log(logging.Warn, "IMPORTANT! '%s: %s' option is required to work properly but your current setting is '%v: %v'.", k, v, k, cd.option(k))
   285  			log(logging.Warn, "It is highly recommended to remove '%v: %v' option from your config!", k, cd.option(k))
   286  		} // or override with `cd.Options[k] = v`?
   287  		if cd.URL != "" && !strings.Contains(cd.URL, k+"="+v) {
   288  			log(logging.Warn, "IMPORTANT! '%s=%s' option is required to work properly. Please add it to the database URL in the config!", k, v)
   289  		} // or fix user specified url?
   290  	}
   291  }
   292  
// mysqlTruncate selects, from INFORMATION_SCHEMA.TABLES, one
// "TRUNCATE TABLE `name`;" statement per table of the schema bound to the
// first placeholder, skipping the table named by the second placeholder and
// anything whose table_type is 'VIEW'. TruncateAll binds the database name
// and the migration table name to the two placeholders.
const mysqlTruncate = "SELECT concat('TRUNCATE TABLE `', TABLE_NAME, '`;') as stmt FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name <> ? AND table_type <> 'VIEW'"