github.com/Accefy/pop@v0.0.0-20230428174248-e9f677eab5b9/dialect_mysql.go (about) 1 package pop 2 3 import ( 4 "bytes" 5 "fmt" 6 "io" 7 "os/exec" 8 "regexp" 9 "strings" 10 11 "github.com/Accefy/pop/internal/defaults" 12 _mysql "github.com/go-sql-driver/mysql" // Load MySQL Go driver 13 "github.com/gobuffalo/fizz" 14 "github.com/gobuffalo/fizz/translators" 15 "github.com/gobuffalo/pop/v6/columns" 16 "github.com/gobuffalo/pop/v6/logging" 17 "github.com/jmoiron/sqlx" 18 ) 19 20 const nameMySQL = "mysql" 21 const hostMySQL = "localhost" 22 const portMySQL = "3306" 23 24 func init() { 25 AvailableDialects = append(AvailableDialects, nameMySQL) 26 urlParser[nameMySQL] = urlParserMySQL 27 finalizer[nameMySQL] = finalizerMySQL 28 newConnection[nameMySQL] = newMySQL 29 } 30 31 var _ dialect = &mysql{} 32 33 type mysql struct { 34 commonDialect 35 } 36 37 func (m *mysql) Name() string { 38 return nameMySQL 39 } 40 41 func (m *mysql) DefaultDriver() string { 42 return nameMySQL 43 } 44 45 func (mysql) Quote(key string) string { 46 return fmt.Sprintf("`%s`", key) 47 } 48 49 func (m *mysql) Details() *ConnectionDetails { 50 return m.ConnectionDetails 51 } 52 53 func (m *mysql) URL() string { 54 cd := m.ConnectionDetails 55 if cd.URL != "" { 56 return strings.TrimPrefix(cd.URL, "mysql://") 57 } 58 59 user := fmt.Sprintf("%s:%s@", cd.User, cd.Password) 60 user = strings.Replace(user, ":@", "@", 1) 61 if user == "@" || strings.HasPrefix(user, ":") { 62 user = "" 63 } 64 65 addr := fmt.Sprintf("(%s:%s)", cd.Host, cd.Port) 66 // in case of unix domain socket, tricky. 67 // it is better to check Host is not valid inet address or has '/'. 68 if cd.Port == "socket" { 69 addr = fmt.Sprintf("unix(%s)", cd.Host) 70 } 71 72 s := "%s%s/%s?%s" 73 return fmt.Sprintf(s, user, addr, cd.Database, cd.OptionsString("")) 74 } 75 76 func (m *mysql) urlWithoutDb() string { 77 cd := m.ConnectionDetails 78 return strings.Replace(m.URL(), "/"+cd.Database+"?", "/?", 1) 79 } 80 81 func (m *mysql) MigrationURL() string { 82 return m.URL() 83 } 84 85 func (m *mysql) Create(c *Connection, model *Model, cols columns.Columns) error { 86 if err := genericCreate(c, model, cols, m); err != nil { 87 return fmt.Errorf("mysql create: %w", err) 88 } 89 return nil 90 } 91 92 func (m *mysql) Update(c *Connection, model *Model, cols columns.Columns) error { 93 if err := genericUpdate(c, model, cols, m); err != nil { 94 return fmt.Errorf("mysql update: %w", err) 95 } 96 return nil 97 } 98 99 func (m *mysql) UpdateQuery(c *Connection, model *Model, cols columns.Columns, query Query) (int64, error) { 100 if n, err := genericUpdateQuery(c, model, cols, m, query, sqlx.QUESTION); err != nil { 101 return n, fmt.Errorf("mysql update query: %w", err) 102 } else { 103 return n, nil 104 } 105 } 106 107 func (m *mysql) Destroy(c *Connection, model *Model) error { 108 stmt := fmt.Sprintf("DELETE FROM %s WHERE %s = ?", m.Quote(model.TableName()), model.IDField()) 109 _, err := genericExec(c, stmt, model.ID()) 110 if err != nil { 111 return fmt.Errorf("mysql destroy: %w", err) 112 } 113 return nil 114 } 115 116 var asRegex = regexp.MustCompile(`\sAS\s\S+`) // exactly " AS non-spaces" 117 118 func (m *mysql) Delete(c *Connection, model *Model, query Query) error { 119 sqlQuery, args := query.ToSQL(model) 120 // * MySQL does not support table alias for DELETE syntax until 8.0. 121 // * Do not generate SQL manually if they may have `WHERE IN`. 122 // * Spaces are intentionally added to make it easy to see on the log. 123 sqlQuery = asRegex.ReplaceAllString(sqlQuery, " ") 124 125 _, err := genericExec(c, sqlQuery, args...) 126 return err 127 } 128 129 func (m *mysql) SelectOne(c *Connection, model *Model, query Query) error { 130 if err := genericSelectOne(c, model, query); err != nil { 131 return fmt.Errorf("mysql select one: %w", err) 132 } 133 return nil 134 } 135 136 func (m *mysql) SelectMany(c *Connection, models *Model, query Query) error { 137 if err := genericSelectMany(c, models, query); err != nil { 138 return fmt.Errorf("mysql select many: %w", err) 139 } 140 return nil 141 } 142 143 // CreateDB creates a new database, from the given connection credentials 144 func (m *mysql) CreateDB() error { 145 deets := m.ConnectionDetails 146 db, err := openPotentiallyInstrumentedConnection(m, m.urlWithoutDb()) 147 if err != nil { 148 return fmt.Errorf("error creating MySQL database %s: %w", deets.Database, err) 149 } 150 defer db.Close() 151 charset := defaults.String(deets.option("charset"), "utf8mb4") 152 encoding := defaults.String(deets.option("collation"), "utf8mb4_general_ci") 153 query := fmt.Sprintf("CREATE DATABASE `%s` DEFAULT CHARSET `%s` DEFAULT COLLATE `%s`", deets.Database, charset, encoding) 154 log(logging.SQL, query) 155 156 _, err = db.Exec(query) 157 if err != nil { 158 return fmt.Errorf("error creating MySQL database %s: %w", deets.Database, err) 159 } 160 161 log(logging.Info, "created database %s", deets.Database) 162 return nil 163 } 164 165 // DropDB drops an existing database, from the given connection credentials 166 func (m *mysql) DropDB() error { 167 deets := m.ConnectionDetails 168 db, err := openPotentiallyInstrumentedConnection(m, m.urlWithoutDb()) 169 if err != nil { 170 return fmt.Errorf("error dropping MySQL database %s: %w", deets.Database, err) 171 } 172 defer db.Close() 173 query := fmt.Sprintf("DROP DATABASE `%s`", deets.Database) 174 log(logging.SQL, query) 175 176 _, err = db.Exec(query) 177 if err != nil { 178 return fmt.Errorf("error dropping MySQL database %s: %w", deets.Database, err) 179 } 180 181 log(logging.Info, "dropped database %s", deets.Database) 182 return nil 183 } 184 185 func (m *mysql) TranslateSQL(sql string) string { 186 return sql 187 } 188 189 func (m *mysql) FizzTranslator() fizz.Translator { 190 t := translators.NewMySQL(m.URL(), m.Details().Database) 191 return t 192 } 193 194 func (m *mysql) DumpSchema(w io.Writer) error { 195 deets := m.Details() 196 cmd := exec.Command("mysqldump", "-d", "-h", deets.Host, "-P", deets.Port, "-u", deets.User, fmt.Sprintf("--password=%s", deets.Password), deets.Database) 197 if deets.Port == "socket" { 198 cmd = exec.Command("mysqldump", "-d", "-S", deets.Host, "-u", deets.User, fmt.Sprintf("--password=%s", deets.Password), deets.Database) 199 } 200 return genericDumpSchema(deets, cmd, w) 201 } 202 203 // LoadSchema executes a schema sql file against the configured database. 204 func (m *mysql) LoadSchema(r io.Reader) error { 205 return genericLoadSchema(m, r) 206 } 207 208 // TruncateAll truncates all tables for the given connection. 209 func (m *mysql) TruncateAll(tx *Connection) error { 210 var stmts []string 211 err := tx.RawQuery(mysqlTruncate, m.Details().Database, tx.MigrationTableName()).All(&stmts) 212 if err != nil { 213 return err 214 } 215 if len(stmts) == 0 { 216 return nil 217 } 218 219 var qb bytes.Buffer 220 // #49: Disable foreign keys before truncation 221 qb.WriteString("SET SESSION FOREIGN_KEY_CHECKS = 0; ") 222 qb.WriteString(strings.Join(stmts, " ")) 223 // #49: Re-enable foreign keys after truncation 224 qb.WriteString(" SET SESSION FOREIGN_KEY_CHECKS = 1;") 225 226 return tx.RawQuery(qb.String()).Exec() 227 } 228 229 func newMySQL(deets *ConnectionDetails) (dialect, error) { 230 cd := &mysql{ 231 commonDialect: commonDialect{ConnectionDetails: deets}, 232 } 233 return cd, nil 234 } 235 236 func urlParserMySQL(cd *ConnectionDetails) error { 237 cfg, err := _mysql.ParseDSN(strings.TrimPrefix(cd.URL, "mysql://")) 238 if err != nil { 239 return fmt.Errorf("the URL '%s' is not supported by MySQL driver: %w", cd.URL, err) 240 } 241 242 cd.User = cfg.User 243 cd.Password = cfg.Passwd 244 cd.Database = cfg.DBName 245 246 // NOTE: use cfg.Params if want to fill options with full parameters 247 cd.setOption("collation", cfg.Collation) 248 249 if cfg.Net == "unix" { 250 cd.Port = "socket" // trick. see: `URL()` 251 cd.Host = cfg.Addr 252 } else { 253 tmp := strings.Split(cfg.Addr, ":") 254 cd.Host = tmp[0] 255 if len(tmp) > 1 { 256 cd.Port = tmp[1] 257 } 258 } 259 260 return nil 261 } 262 263 func finalizerMySQL(cd *ConnectionDetails) { 264 cd.Host = defaults.String(cd.Host, hostMySQL) 265 cd.Port = defaults.String(cd.Port, portMySQL) 266 267 defs := map[string]string{ 268 "readTimeout": "3s", 269 "collation": "utf8mb4_general_ci", 270 } 271 forced := map[string]string{ 272 "parseTime": "true", 273 "multiStatements": "true", 274 } 275 276 for k, def := range defs { 277 cd.setOptionWithDefault(k, cd.option(k), def) 278 } 279 280 for k, v := range forced { 281 // respect user specified options but print warning! 282 cd.setOptionWithDefault(k, cd.option(k), v) 283 if cd.option(k) != v { // when user-defined option exists 284 log(logging.Warn, "IMPORTANT! '%s: %s' option is required to work properly but your current setting is '%v: %v'.", k, v, k, cd.option(k)) 285 log(logging.Warn, "It is highly recommended to remove '%v: %v' option from your config!", k, cd.option(k)) 286 } // or override with `cd.Options[k] = v`? 287 if cd.URL != "" && !strings.Contains(cd.URL, k+"="+v) { 288 log(logging.Warn, "IMPORTANT! '%s=%s' option is required to work properly. Please add it to the database URL in the config!", k, v) 289 } // or fix user specified url? 290 } 291 } 292 293 const mysqlTruncate = "SELECT concat('TRUNCATE TABLE `', TABLE_NAME, '`;') as stmt FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name <> ? AND table_type <> 'VIEW'"