github.com/rjgonzale/pop/v5@v5.1.3-dev/dialect_mysql.go (about) 1 package pop 2 3 import ( 4 "bytes" 5 "database/sql" 6 "fmt" 7 "io" 8 "os/exec" 9 "strings" 10 11 // Load MySQL Go driver 12 _mysql "github.com/go-sql-driver/mysql" 13 "github.com/gobuffalo/fizz" 14 "github.com/gobuffalo/fizz/translators" 15 "github.com/gobuffalo/pop/v5/columns" 16 "github.com/gobuffalo/pop/v5/internal/defaults" 17 "github.com/gobuffalo/pop/v5/logging" 18 "github.com/pkg/errors" 19 ) 20 21 const nameMySQL = "mysql" 22 const hostMySQL = "localhost" 23 const portMySQL = "3306" 24 25 func init() { 26 AvailableDialects = append(AvailableDialects, nameMySQL) 27 urlParser[nameMySQL] = urlParserMySQL 28 finalizer[nameMySQL] = finalizerMySQL 29 newConnection[nameMySQL] = newMySQL 30 } 31 32 var _ dialect = &mysql{} 33 34 type mysql struct { 35 commonDialect 36 } 37 38 func (m *mysql) Name() string { 39 return nameMySQL 40 } 41 42 func (m *mysql) DefaultDriver() string { 43 return nameMySQL 44 } 45 46 func (mysql) Quote(key string) string { 47 return fmt.Sprintf("`%s`", key) 48 } 49 50 func (m *mysql) Details() *ConnectionDetails { 51 return m.ConnectionDetails 52 } 53 54 func (m *mysql) URL() string { 55 cd := m.ConnectionDetails 56 if cd.URL != "" { 57 return strings.TrimPrefix(cd.URL, "mysql://") 58 } 59 60 user := fmt.Sprintf("%s:%s@", cd.User, cd.Password) 61 user = strings.Replace(user, ":@", "@", 1) 62 if user == "@" || strings.HasPrefix(user, ":") { 63 user = "" 64 } 65 66 addr := fmt.Sprintf("(%s:%s)", cd.Host, cd.Port) 67 // in case of unix domain socket, tricky. 68 // it is better to check Host is not valid inet address or has '/'. 69 if cd.Port == "socket" { 70 addr = fmt.Sprintf("unix(%s)", cd.Host) 71 } 72 73 s := "%s%s/%s?%s" 74 return fmt.Sprintf(s, user, addr, cd.Database, cd.OptionsString("")) 75 } 76 77 func (m *mysql) urlWithoutDb() string { 78 cd := m.ConnectionDetails 79 return strings.Replace(m.URL(), "/"+cd.Database+"?", "/?", 1) 80 } 81 82 func (m *mysql) MigrationURL() string { 83 return m.URL() 84 } 85 86 func (m *mysql) Create(s store, model *Model, cols columns.Columns) error { 87 return errors.Wrap(genericCreate(s, model, cols, m), "mysql create") 88 } 89 90 func (m *mysql) Update(s store, model *Model, cols columns.Columns) error { 91 return errors.Wrap(genericUpdate(s, model, cols, m), "mysql update") 92 } 93 94 func (m *mysql) Destroy(s store, model *Model) error { 95 return errors.Wrap(genericDestroy(s, model, m), "mysql destroy") 96 } 97 98 func (m *mysql) SelectOne(s store, model *Model, query Query) error { 99 return errors.Wrap(genericSelectOne(s, model, query), "mysql select one") 100 } 101 102 func (m *mysql) SelectMany(s store, models *Model, query Query) error { 103 return errors.Wrap(genericSelectMany(s, models, query), "mysql select many") 104 } 105 106 // CreateDB creates a new database, from the given connection credentials 107 func (m *mysql) CreateDB() error { 108 deets := m.ConnectionDetails 109 db, err := sql.Open(deets.Dialect, m.urlWithoutDb()) 110 if err != nil { 111 return errors.Wrapf(err, "error creating MySQL database %s", deets.Database) 112 } 113 defer db.Close() 114 charset := defaults.String(deets.Options["charset"], "utf8mb4") 115 encoding := defaults.String(deets.Options["collation"], "utf8mb4_general_ci") 116 query := fmt.Sprintf("CREATE DATABASE `%s` DEFAULT CHARSET `%s` DEFAULT COLLATE `%s`", deets.Database, charset, encoding) 117 log(logging.SQL, query) 118 119 _, err = db.Exec(query) 120 if err != nil { 121 return errors.Wrapf(err, "error creating MySQL database %s", deets.Database) 122 } 123 124 log(logging.Info, "created database %s", deets.Database) 125 return nil 126 } 127 128 // DropDB drops an existing database, from the given connection credentials 129 func (m *mysql) DropDB() error { 130 deets := m.ConnectionDetails 131 db, err := sql.Open(deets.Dialect, m.urlWithoutDb()) 132 if err != nil { 133 return errors.Wrapf(err, "error dropping MySQL database %s", deets.Database) 134 } 135 defer db.Close() 136 query := fmt.Sprintf("DROP DATABASE `%s`", deets.Database) 137 log(logging.SQL, query) 138 139 _, err = db.Exec(query) 140 if err != nil { 141 return errors.Wrapf(err, "error dropping MySQL database %s", deets.Database) 142 } 143 144 log(logging.Info, "dropped database %s", deets.Database) 145 return nil 146 } 147 148 func (m *mysql) TranslateSQL(sql string) string { 149 return sql 150 } 151 152 func (m *mysql) FizzTranslator() fizz.Translator { 153 t := translators.NewMySQL(m.URL(), m.Details().Database) 154 return t 155 } 156 157 func (m *mysql) DumpSchema(w io.Writer) error { 158 deets := m.Details() 159 cmd := exec.Command("mysqldump", "-d", "-h", deets.Host, "-P", deets.Port, "-u", deets.User, fmt.Sprintf("--password=%s", deets.Password), deets.Database) 160 if deets.Port == "socket" { 161 cmd = exec.Command("mysqldump", "-d", "-S", deets.Host, "-u", deets.User, fmt.Sprintf("--password=%s", deets.Password), deets.Database) 162 } 163 return genericDumpSchema(deets, cmd, w) 164 } 165 166 // LoadSchema executes a schema sql file against the configured database. 167 func (m *mysql) LoadSchema(r io.Reader) error { 168 return genericLoadSchema(m.ConnectionDetails, m.MigrationURL(), r) 169 } 170 171 // TruncateAll truncates all tables for the given connection. 172 func (m *mysql) TruncateAll(tx *Connection) error { 173 var stmts []string 174 err := tx.RawQuery(mysqlTruncate, m.Details().Database, tx.MigrationTableName()).All(&stmts) 175 if err != nil { 176 return err 177 } 178 if len(stmts) == 0 { 179 return nil 180 } 181 182 var qb bytes.Buffer 183 // #49: Disable foreign keys before truncation 184 qb.WriteString("SET SESSION FOREIGN_KEY_CHECKS = 0; ") 185 qb.WriteString(strings.Join(stmts, " ")) 186 // #49: Re-enable foreign keys after truncation 187 qb.WriteString(" SET SESSION FOREIGN_KEY_CHECKS = 1;") 188 189 return tx.RawQuery(qb.String()).Exec() 190 } 191 192 func newMySQL(deets *ConnectionDetails) (dialect, error) { 193 cd := &mysql{ 194 commonDialect: commonDialect{ConnectionDetails: deets}, 195 } 196 return cd, nil 197 } 198 199 func urlParserMySQL(cd *ConnectionDetails) error { 200 cfg, err := _mysql.ParseDSN(strings.TrimPrefix(cd.URL, "mysql://")) 201 if err != nil { 202 return errors.Wrapf(err, "the URL '%s' is not supported by MySQL driver", cd.URL) 203 } 204 205 cd.User = cfg.User 206 cd.Password = cfg.Passwd 207 cd.Database = cfg.DBName 208 if cd.Options == nil { // prevent panic 209 cd.Options = make(map[string]string) 210 } 211 // NOTE: use cfg.Params if want to fill options with full parameters 212 cd.Options["collation"] = cfg.Collation 213 if cfg.Net == "unix" { 214 cd.Port = "socket" // trick. see: `URL()` 215 cd.Host = cfg.Addr 216 } else { 217 tmp := strings.Split(cfg.Addr, ":") 218 cd.Host = tmp[0] 219 if len(tmp) > 1 { 220 cd.Port = tmp[1] 221 } 222 } 223 224 return nil 225 } 226 227 func finalizerMySQL(cd *ConnectionDetails) { 228 cd.Host = defaults.String(cd.Host, hostMySQL) 229 cd.Port = defaults.String(cd.Port, portMySQL) 230 231 defs := map[string]string{ 232 "readTimeout": "3s", 233 "collation": "utf8mb4_general_ci", 234 } 235 forced := map[string]string{ 236 "parseTime": "true", 237 "multiStatements": "true", 238 } 239 240 if cd.Options == nil { // prevent panic 241 cd.Options = make(map[string]string) 242 } 243 244 for k, v := range defs { 245 cd.Options[k] = defaults.String(cd.Options[k], v) 246 } 247 248 for k, v := range forced { 249 // respect user specified options but print warning! 250 cd.Options[k] = defaults.String(cd.Options[k], v) 251 if cd.Options[k] != v { // when user-defined option exists 252 log(logging.Warn, "IMPORTANT! '%s: %s' option is required to work properly but your current setting is '%v: %v'.", k, v, k, cd.Options[k]) 253 log(logging.Warn, "It is highly recommended to remove '%v: %v' option from your config!", k, cd.Options[k]) 254 } // or override with `cd.Options[k] = v`? 255 if cd.URL != "" && !strings.Contains(cd.URL, k+"="+v) { 256 log(logging.Warn, "IMPORTANT! '%s=%s' option is required to work properly. Please add it to the database URL in the config!", k, v) 257 } // or fix user specified url? 258 } 259 } 260 261 const mysqlTruncate = "SELECT concat('TRUNCATE TABLE `', TABLE_NAME, '`;') as stmt FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name <> ? AND table_type <> 'VIEW'"