github.com/dkishere/pop/v6@v6.103.1/dialect_mysql.go (about) 1 package pop 2 3 import ( 4 "bytes" 5 "fmt" 6 "io" 7 "os/exec" 8 "strings" 9 10 "github.com/dkishere/pop/v6/columns" 11 "github.com/dkishere/pop/v6/internal/defaults" 12 "github.com/dkishere/pop/v6/logging" 13 _mysql "github.com/go-sql-driver/mysql" // Load MySQL Go driver 14 "github.com/gobuffalo/fizz" 15 "github.com/gobuffalo/fizz/translators" 16 ) 17 18 const nameMySQL = "mysql" 19 const hostMySQL = "localhost" 20 const portMySQL = "3306" 21 22 func init() { 23 AvailableDialects = append(AvailableDialects, nameMySQL) 24 urlParser[nameMySQL] = urlParserMySQL 25 finalizer[nameMySQL] = finalizerMySQL 26 newConnection[nameMySQL] = newMySQL 27 } 28 29 var _ dialect = &mysql{} 30 31 type mysql struct { 32 commonDialect 33 } 34 35 func (m *mysql) Name() string { 36 return nameMySQL 37 } 38 39 func (m *mysql) DefaultDriver() string { 40 return nameMySQL 41 } 42 43 func (mysql) Quote(key string) string { 44 return fmt.Sprintf("`%s`", key) 45 } 46 47 func (m *mysql) Details() *ConnectionDetails { 48 return m.ConnectionDetails 49 } 50 51 func (m *mysql) URL() string { 52 cd := m.ConnectionDetails 53 if cd.URL != "" { 54 return strings.TrimPrefix(cd.URL, "mysql://") 55 } 56 57 user := fmt.Sprintf("%s:%s@", cd.User, cd.Password) 58 user = strings.Replace(user, ":@", "@", 1) 59 if user == "@" || strings.HasPrefix(user, ":") { 60 user = "" 61 } 62 63 addr := fmt.Sprintf("(%s:%s)", cd.Host, cd.Port) 64 // in case of unix domain socket, tricky. 65 // it is better to check Host is not valid inet address or has '/'. 66 if cd.Port == "socket" { 67 addr = fmt.Sprintf("unix(%s)", cd.Host) 68 } 69 70 s := "%s%s/%s?%s" 71 return fmt.Sprintf(s, user, addr, cd.Database, cd.OptionsString("")) 72 } 73 74 func (m *mysql) urlWithoutDb() string { 75 cd := m.ConnectionDetails 76 return strings.Replace(m.URL(), "/"+cd.Database+"?", "/?", 1) 77 } 78 79 func (m *mysql) MigrationURL() string { 80 return m.URL() 81 } 82 83 func (m *mysql) Create(s store, model *Model, cols columns.Columns) error { 84 if err := genericCreate(s, model, cols, m); err != nil { 85 return fmt.Errorf("mysql create: %w", err) 86 } 87 return nil 88 } 89 90 func (m *mysql) Update(s store, model *Model, cols columns.Columns) error { 91 if err := genericUpdate(s, model, cols, m); err != nil { 92 return fmt.Errorf("mysql update: %w", err) 93 } 94 return nil 95 } 96 97 func (m *mysql) Destroy(s store, model *Model) error { 98 stmt := fmt.Sprintf("DELETE FROM %s WHERE %s = ?", m.Quote(model.TableName()), model.IDField()) 99 _, err := genericExec(s, stmt, model.ID()) 100 if err != nil { 101 return fmt.Errorf("mysql destroy: %w", err) 102 } 103 return nil 104 } 105 106 func (m *mysql) Delete(s store, model *Model, query Query) error { 107 sb := query.toSQLBuilder(model) 108 109 sql := fmt.Sprintf("DELETE FROM %s", m.Quote(model.TableName())) 110 sql = sb.buildWhereClauses(sql) 111 112 _, err := genericExec(s, sql, sb.Args()...) 113 if err != nil { 114 return fmt.Errorf("mysql delete: %w", err) 115 } 116 return nil 117 } 118 119 func (m *mysql) SelectOne(s store, model *Model, query Query) error { 120 if err := genericSelectOne(s, model, query); err != nil { 121 return fmt.Errorf("mysql select one: %w", err) 122 } 123 return nil 124 } 125 126 func (m *mysql) SelectMany(s store, models *Model, query Query) error { 127 if err := genericSelectMany(s, models, query); err != nil { 128 return fmt.Errorf("mysql select many: %w", err) 129 } 130 return nil 131 } 132 133 // CreateDB creates a new database, from the given connection credentials 134 func (m *mysql) CreateDB() error { 135 deets := m.ConnectionDetails 136 db, err := openPotentiallyInstrumentedConnection(m, m.urlWithoutDb()) 137 if err != nil { 138 return fmt.Errorf("error creating MySQL database %s: %w", deets.Database, err) 139 } 140 defer db.Close() 141 charset := defaults.String(deets.Options["charset"], "utf8mb4") 142 encoding := defaults.String(deets.Options["collation"], "utf8mb4_general_ci") 143 query := fmt.Sprintf("CREATE DATABASE `%s` DEFAULT CHARSET `%s` DEFAULT COLLATE `%s`", deets.Database, charset, encoding) 144 log(logging.SQL, query) 145 146 _, err = db.Exec(query) 147 if err != nil { 148 return fmt.Errorf("error creating MySQL database %s: %w", deets.Database, err) 149 } 150 151 log(logging.Info, "created database %s", deets.Database) 152 return nil 153 } 154 155 // DropDB drops an existing database, from the given connection credentials 156 func (m *mysql) DropDB() error { 157 deets := m.ConnectionDetails 158 db, err := openPotentiallyInstrumentedConnection(m, m.urlWithoutDb()) 159 if err != nil { 160 return fmt.Errorf("error dropping MySQL database %s: %w", deets.Database, err) 161 } 162 defer db.Close() 163 query := fmt.Sprintf("DROP DATABASE `%s`", deets.Database) 164 log(logging.SQL, query) 165 166 _, err = db.Exec(query) 167 if err != nil { 168 return fmt.Errorf("error dropping MySQL database %s: %w", deets.Database, err) 169 } 170 171 log(logging.Info, "dropped database %s", deets.Database) 172 return nil 173 } 174 175 func (m *mysql) TranslateSQL(sql string) string { 176 return sql 177 } 178 179 func (m *mysql) FizzTranslator() fizz.Translator { 180 t := translators.NewMySQL(m.URL(), m.Details().Database) 181 return t 182 } 183 184 func (m *mysql) DumpSchema(w io.Writer) error { 185 deets := m.Details() 186 // Github CI is currently using mysql:5.7 but the mysqldump version doesn't seem to match 187 cmd := exec.Command("mysqldump", "--column-statistics=0", "-d", "-h", deets.Host, "-P", deets.Port, "-u", deets.User, fmt.Sprintf("--password=%s", deets.Password), deets.Database) 188 if deets.Port == "socket" { 189 cmd = exec.Command("mysqldump", "--column-statistics=0", "-d", "-S", deets.Host, "-u", deets.User, fmt.Sprintf("--password=%s", deets.Password), deets.Database) 190 } 191 return genericDumpSchema(deets, cmd, w) 192 } 193 194 // LoadSchema executes a schema sql file against the configured database. 195 func (m *mysql) LoadSchema(r io.Reader) error { 196 return genericLoadSchema(m, r) 197 } 198 199 // TruncateAll truncates all tables for the given connection. 200 func (m *mysql) TruncateAll(tx *Connection) error { 201 var stmts []string 202 err := tx.RawQuery(mysqlTruncate, m.Details().Database, tx.MigrationTableName()).All(&stmts) 203 if err != nil { 204 return err 205 } 206 if len(stmts) == 0 { 207 return nil 208 } 209 210 var qb bytes.Buffer 211 // #49: Disable foreign keys before truncation 212 qb.WriteString("SET SESSION FOREIGN_KEY_CHECKS = 0; ") 213 qb.WriteString(strings.Join(stmts, " ")) 214 // #49: Re-enable foreign keys after truncation 215 qb.WriteString(" SET SESSION FOREIGN_KEY_CHECKS = 1;") 216 217 return tx.RawQuery(qb.String()).Exec() 218 } 219 220 func newMySQL(deets *ConnectionDetails) (dialect, error) { 221 cd := &mysql{ 222 commonDialect: commonDialect{ConnectionDetails: deets}, 223 } 224 return cd, nil 225 } 226 227 func urlParserMySQL(cd *ConnectionDetails) error { 228 cfg, err := _mysql.ParseDSN(strings.TrimPrefix(cd.URL, "mysql://")) 229 if err != nil { 230 return fmt.Errorf("the URL '%s' is not supported by MySQL driver: %w", cd.URL, err) 231 } 232 233 cd.User = cfg.User 234 cd.Password = cfg.Passwd 235 cd.Database = cfg.DBName 236 if cd.Options == nil { // prevent panic 237 cd.Options = make(map[string]string) 238 } 239 // NOTE: use cfg.Params if want to fill options with full parameters 240 cd.Options["collation"] = cfg.Collation 241 if cfg.Net == "unix" { 242 cd.Port = "socket" // trick. see: `URL()` 243 cd.Host = cfg.Addr 244 } else { 245 tmp := strings.Split(cfg.Addr, ":") 246 cd.Host = tmp[0] 247 if len(tmp) > 1 { 248 cd.Port = tmp[1] 249 } 250 } 251 252 return nil 253 } 254 255 func finalizerMySQL(cd *ConnectionDetails) { 256 cd.Host = defaults.String(cd.Host, hostMySQL) 257 cd.Port = defaults.String(cd.Port, portMySQL) 258 259 defs := map[string]string{ 260 "readTimeout": "3s", 261 "collation": "utf8mb4_general_ci", 262 } 263 forced := map[string]string{ 264 "parseTime": "true", 265 "multiStatements": "true", 266 } 267 268 if cd.Options == nil { // prevent panic 269 cd.Options = make(map[string]string) 270 } 271 272 for k, v := range defs { 273 cd.Options[k] = defaults.String(cd.Options[k], v) 274 } 275 276 for k, v := range forced { 277 // respect user specified options but print warning! 278 cd.Options[k] = defaults.String(cd.Options[k], v) 279 if cd.Options[k] != v { // when user-defined option exists 280 log(logging.Warn, "IMPORTANT! '%s: %s' option is required to work properly but your current setting is '%v: %v'.", k, v, k, cd.Options[k]) 281 log(logging.Warn, "It is highly recommended to remove '%v: %v' option from your config!", k, cd.Options[k]) 282 } // or override with `cd.Options[k] = v`? 283 if cd.URL != "" && !strings.Contains(cd.URL, k+"="+v) { 284 log(logging.Warn, "IMPORTANT! '%s=%s' option is required to work properly. Please add it to the database URL in the config!", k, v) 285 } // or fix user specified url? 286 } 287 } 288 289 const mysqlTruncate = "SELECT concat('TRUNCATE TABLE `', TABLE_NAME, '`;') as stmt FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name <> ? AND table_type <> 'VIEW'"