github.com/dolthub/go-mysql-server@v0.18.0/enginetest/mysqlshim/database.go (about) 1 // Copyright 2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package mysqlshim 16 17 import ( 18 "fmt" 19 "io" 20 "strings" 21 "time" 22 23 "github.com/dolthub/go-mysql-server/sql" 24 ) 25 26 // Database represents a database for a local MySQL server. 27 type Database struct { 28 shim *MySQLShim 29 name string 30 } 31 32 var _ sql.Database = Database{} 33 var _ sql.TableCreator = Database{} 34 var _ sql.TableDropper = Database{} 35 var _ sql.TableRenamer = Database{} 36 var _ sql.TriggerDatabase = Database{} 37 var _ sql.StoredProcedureDatabase = Database{} 38 var _ sql.ViewDatabase = Database{} 39 var _ sql.EventDatabase = Database{} 40 41 // Name implements the interface sql.Database. 42 func (d Database) Name() string { 43 return d.name 44 } 45 46 // GetTableInsensitive implements the interface sql.Database. 47 func (d Database) GetTableInsensitive(ctx *sql.Context, tblName string) (sql.Table, bool, error) { 48 tables, err := d.GetTableNames(ctx) 49 if err != nil { 50 return nil, false, err 51 } 52 lowerName := strings.ToLower(tblName) 53 for _, readName := range tables { 54 if lowerName == strings.ToLower(readName) { 55 return Table{d, readName}, true, nil 56 } 57 } 58 return nil, false, nil 59 } 60 61 // GetTableNames implements the interface sql.Database. 62 func (d Database) GetTableNames(ctx *sql.Context) ([]string, error) { 63 rows, err := d.shim.Query(d.name, "SHOW TABLES;") 64 if err != nil { 65 return nil, err 66 } 67 defer rows.Close(ctx) 68 var tableNames []string 69 var row sql.Row 70 for row, err = rows.Next(ctx); err == nil; row, err = rows.Next(ctx) { 71 tableNames = append(tableNames, row[0].(string)) 72 } 73 if err != io.EOF { 74 return nil, err 75 } 76 return tableNames, nil 77 } 78 79 // CreateTable implements the interface sql.TableCreator. 80 func (d Database) CreateTable(ctx *sql.Context, name string, schema sql.PrimaryKeySchema, collation sql.CollationID, comment string) error { 81 colStmts := make([]string, len(schema.Schema)) 82 var primaryKeyCols []string 83 for i, col := range schema.Schema { 84 stmt := fmt.Sprintf(" `%s` %s", col.Name, col.Type.String()) 85 if !col.Nullable { 86 stmt = fmt.Sprintf("%s NOT NULL", stmt) 87 } 88 if col.AutoIncrement { 89 stmt = fmt.Sprintf("%s AUTO_INCREMENT", stmt) 90 } 91 if col.Default != nil { 92 stmt = fmt.Sprintf("%s DEFAULT %s", stmt, col.Default.String()) 93 } 94 if col.Comment != "" { 95 stmt = fmt.Sprintf("%s COMMENT '%s'", stmt, col.Comment) 96 } 97 if col.PrimaryKey { 98 primaryKeyCols = append(primaryKeyCols, col.Name) 99 } 100 colStmts[i] = stmt 101 } 102 if len(primaryKeyCols) > 0 { 103 primaryKey := fmt.Sprintf(" PRIMARY KEY (`%s`)", strings.Join(primaryKeyCols, "`,`")) 104 colStmts = append(colStmts, primaryKey) 105 } 106 return d.shim.Exec(d.name, fmt.Sprintf("CREATE TABLE `%s` (\n%s\n) ENGINE=InnoDB DEFAULT COLLATE=%s COMMENT='%s';", 107 name, strings.Join(colStmts, ",\n"), sql.Collation_Default.String(), comment)) 108 } 109 110 // DropTable implements the interface sql.TableDropper. 111 func (d Database) DropTable(ctx *sql.Context, name string) error { 112 return d.shim.Exec(d.name, fmt.Sprintf("DROP TABLE `%s`;", name)) 113 } 114 115 // RenameTable implements the interface sql.TableRenamer. 116 func (d Database) RenameTable(ctx *sql.Context, oldName, newName string) error { 117 return d.shim.Exec(d.name, fmt.Sprintf("RENAME TABLE `%s` TO `%s`;", oldName, newName)) 118 } 119 120 // GetTriggers implements the interface sql.TriggerDatabase. 121 func (d Database) GetTriggers(ctx *sql.Context) ([]sql.TriggerDefinition, error) { 122 rows, err := d.shim.Query(d.name, "SHOW TRIGGERS;") 123 if err != nil { 124 return nil, err 125 } 126 defer rows.Close(ctx) 127 var triggers []sql.TriggerDefinition 128 var row sql.Row 129 for row, err = rows.Next(ctx); err == nil; row, err = rows.Next(ctx) { 130 // Trigger, Event, Table, Statement, Timing, Created, sql_mode, ... 131 triggers = append(triggers, sql.TriggerDefinition{ 132 Name: row[0].(string), 133 CreateStatement: fmt.Sprintf("CREATE TRIGGER `%s` %s %s ON `%s` FOR EACH ROW %s;", 134 row[0].(string), row[4].(string), row[1].(string), row[2].(string), row[3].(string)), 135 CreatedAt: time.Time{}, // TODO: time works in with doltharness 136 }) 137 } 138 if err != io.EOF { 139 return nil, err 140 } 141 return triggers, nil 142 } 143 144 // CreateTrigger implements the interface sql.TriggerDatabase. 145 func (d Database) CreateTrigger(ctx *sql.Context, definition sql.TriggerDefinition) error { 146 return d.shim.Exec(d.name, definition.CreateStatement) 147 } 148 149 // DropTrigger implements the interface sql.TriggerDatabase. 150 func (d Database) DropTrigger(ctx *sql.Context, name string) error { 151 return d.shim.Exec(d.name, fmt.Sprintf("DROP TRIGGER `%s`;", name)) 152 } 153 154 // GetStoredProcedure implements the interface sql.StoredProcedureDatabase. 155 func (d Database) GetStoredProcedure(ctx *sql.Context, name string) (sql.StoredProcedureDetails, bool, error) { 156 name = strings.ToLower(name) 157 procedures, err := d.GetStoredProcedures(ctx) 158 if err != nil { 159 return sql.StoredProcedureDetails{}, false, err 160 } 161 for _, procedure := range procedures { 162 if name == strings.ToLower(procedure.Name) { 163 return procedure, true, nil 164 } 165 } 166 return sql.StoredProcedureDetails{}, false, nil 167 } 168 169 // GetStoredProcedures implements the interface sql.StoredProcedureDatabase. 170 func (d Database) GetStoredProcedures(ctx *sql.Context) ([]sql.StoredProcedureDetails, error) { 171 procedures, err := d.shim.QueryRows("", fmt.Sprintf("SHOW PROCEDURE STATUS WHERE Db = '%s';", d.name)) 172 if err != nil { 173 return nil, err 174 } 175 storedProcedureDetails := make([]sql.StoredProcedureDetails, len(procedures)) 176 for i, procedure := range procedures { 177 // Db, Name, Type, Definer, Modified, Created, Security_type, Comment, ... 178 procedureStatement, err := d.shim.QueryRows("", fmt.Sprintf("SHOW CREATE PROCEDURE `%s`.`%s`;", d.name, procedure[1])) 179 if err != nil { 180 return nil, err 181 } 182 // Procedure, sql_mode, Create Procedure, ... 183 storedProcedureDetails[i] = sql.StoredProcedureDetails{ 184 Name: procedureStatement[0][0].(string), 185 CreateStatement: procedureStatement[0][2].(string), 186 CreatedAt: time.Time{}, // these should be added someday 187 ModifiedAt: time.Time{}, 188 } 189 } 190 return storedProcedureDetails, nil 191 } 192 193 // SaveStoredProcedure implements the interface sql.StoredProcedureDatabase. 194 func (d Database) SaveStoredProcedure(ctx *sql.Context, spd sql.StoredProcedureDetails) error { 195 return d.shim.Exec(d.name, spd.CreateStatement) 196 } 197 198 // DropStoredProcedure implements the interface sql.StoredProcedureDatabase. 199 func (d Database) DropStoredProcedure(ctx *sql.Context, name string) error { 200 return d.shim.Exec(d.name, fmt.Sprintf("DROP PROCEDURE `%s`;", name)) 201 } 202 203 // GetEvent implements sql.EventDatabase 204 func (d Database) GetEvent(ctx *sql.Context, name string) (sql.EventDefinition, bool, error) { 205 name = strings.ToLower(name) 206 events, _, err := d.GetEvents(ctx) 207 if err != nil { 208 return sql.EventDefinition{}, false, err 209 } 210 for _, event := range events { 211 if name == strings.ToLower(event.Name) { 212 return event, true, nil 213 } 214 } 215 return sql.EventDefinition{}, false, nil 216 } 217 218 // GetEvents implements sql.EventDatabase 219 func (d Database) GetEvents(_ *sql.Context) ([]sql.EventDefinition, interface{}, error) { 220 events, err := d.shim.QueryRows("", fmt.Sprintf("SHOW EVENTS WHERE Db = '%s';", d.name)) 221 if err != nil { 222 return nil, nil, err 223 } 224 eventDefinition := make([]sql.EventDefinition, len(events)) 225 for i, event := range events { 226 // Db, Name, Definer, Time Zone, Type, ... 227 eventStmt, err := d.shim.QueryRows("", fmt.Sprintf("SHOW CREATE EVENT `%s`.`%s`;", d.name, event[1])) 228 if err != nil { 229 return nil, nil, err 230 } 231 // Event, sql_mode, time_zone, Create Event, ... 232 eventDefinition[i] = sql.EventDefinition{ 233 Name: eventStmt[0][0].(string), 234 // TODO: other fields should be added such as Created, LastAltered 235 } 236 } 237 // MySQL shim doesn't support event reloading, so token is always nil 238 return eventDefinition, nil, nil 239 } 240 241 // SaveEvent implements sql.EventDatabase 242 func (d Database) SaveEvent(ctx *sql.Context, event sql.EventDefinition) (bool, error) { 243 return event.Status == sql.EventStatus_Enable.String(), d.shim.Exec(d.name, event.CreateEventStatement()) 244 } 245 246 // DropEvent implements sql.EventDatabase 247 func (d Database) DropEvent(ctx *sql.Context, name string) error { 248 return d.shim.Exec(d.name, fmt.Sprintf("DROP EVENT `%s`;", name)) 249 } 250 251 // UpdateEvent implements sql.EventDatabase 252 func (d Database) UpdateEvent(_ *sql.Context, originalName string, event sql.EventDefinition) (bool, error) { 253 err := d.shim.Exec(d.name, fmt.Sprintf("DROP EVENT `%s`;", originalName)) 254 if err != nil { 255 return false, err 256 } 257 return event.Status == sql.EventStatus_Enable.String(), d.shim.Exec(d.name, event.CreateEventStatement()) 258 } 259 260 // NeedsToReloadEvents implements sql.EventDatabase 261 func (d Database) NeedsToReloadEvents(_ *sql.Context, _ interface{}) (bool, error) { 262 // mysqlshim does not support event reloading 263 return false, nil 264 } 265 266 // UpdateLastExecuted implements sql.EventDatabase 267 func (d Database) UpdateLastExecuted(ctx *sql.Context, eventName string, lastExecuted time.Time) error { 268 return nil 269 } 270 271 // CreateView implements the interface sql.ViewDatabase. 272 func (d Database) CreateView(ctx *sql.Context, name string, selectStatement, createViewStmt string) error { 273 return d.shim.Exec(d.name, createViewStmt) 274 } 275 276 // DropView implements the interface sql.ViewDatabase. 277 func (d Database) DropView(ctx *sql.Context, name string) error { 278 return d.shim.Exec(d.name, fmt.Sprintf("DROP VIEW `%s`;", name)) 279 } 280 281 // GetViewDefinition implements the interface sql.ViewDatabase. 282 func (d Database) GetViewDefinition(ctx *sql.Context, viewName string) (sql.ViewDefinition, bool, error) { 283 views, err := d.AllViews(ctx) 284 if err != nil { 285 return sql.ViewDefinition{}, false, err 286 } 287 lowerName := strings.ToLower(viewName) 288 for _, view := range views { 289 if lowerName == strings.ToLower(view.Name) { 290 return view, true, nil 291 } 292 } 293 return sql.ViewDefinition{}, false, nil 294 } 295 296 // AllViews implements the interface sql.ViewDatabase. 297 func (d Database) AllViews(ctx *sql.Context) ([]sql.ViewDefinition, error) { 298 views, err := d.shim.QueryRows("", fmt.Sprintf("SELECT * FROM information_schema.TABLES WHERE TABLE_SCHEMA = '%s' AND TABLE_TYPE = 'VIEW';", d.name)) 299 if err != nil { 300 return nil, err 301 } 302 viewDefinitions := make([]sql.ViewDefinition, len(views)) 303 for i, view := range views { 304 viewName := view[2].(string) 305 viewStatementRow, err := d.shim.QueryRows("", fmt.Sprintf("SHOW CREATE VIEW `%s`.`%s`;", d.name, viewName)) 306 if err != nil { 307 return nil, err 308 } 309 createViewStatement := viewStatementRow[0][1].(string) 310 viewStatement := createViewStatement[strings.Index(createViewStatement, " AS ")+4:] // not the best but works for now 311 viewDefinitions[i] = sql.ViewDefinition{ 312 Name: viewName, 313 TextDefinition: viewStatement, 314 CreateViewStatement: createViewStatement, 315 } 316 } 317 return viewDefinitions, nil 318 }