github.com/dolthub/go-mysql-server@v0.18.0/enginetest/mysqlshim/table.go (about) 1 // Copyright 2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package mysqlshim 16 17 import ( 18 "fmt" 19 "math/rand" 20 "strings" 21 22 "github.com/dolthub/go-mysql-server/sql/planbuilder" 23 24 "github.com/dolthub/go-mysql-server/sql" 25 "github.com/dolthub/go-mysql-server/sql/plan" 26 "github.com/dolthub/go-mysql-server/sql/types" 27 ) 28 29 // Table represents a table for a local MySQL server. 30 type Table struct { 31 db Database 32 name string 33 } 34 35 var _ sql.Table = Table{} 36 var _ sql.InsertableTable = Table{} 37 var _ sql.UpdatableTable = Table{} 38 var _ sql.DeletableTable = Table{} 39 var _ sql.ReplaceableTable = Table{} 40 var _ sql.TruncateableTable = Table{} 41 var _ sql.IndexAddressableTable = Table{} 42 var _ sql.AlterableTable = Table{} 43 var _ sql.IndexAlterableTable = Table{} 44 var _ sql.ForeignKeyTable = Table{} 45 var _ sql.CheckAlterableTable = Table{} 46 var _ sql.CheckTable = Table{} 47 var _ sql.StatisticsTable = Table{} 48 var _ sql.PrimaryKeyAlterableTable = Table{} 49 50 func (t Table) IndexedAccess(sql.IndexLookup) sql.IndexedTable { 51 panic("not implemented") 52 } 53 54 func (t Table) PreciseMatch() bool { 55 return true 56 } 57 58 func (t Table) IndexedPartitions(ctx *sql.Context, _ sql.IndexLookup) (sql.PartitionIter, error) { 59 return t.Partitions(ctx) 60 } 61 62 // Name implements the interface sql.Table. 63 func (t Table) Name() string { 64 return t.name 65 } 66 67 // String implements the interface sql.Table. 68 func (t Table) String() string { 69 return t.name 70 } 71 72 // Schema implements the interface sql.Table. 73 func (t Table) Schema() sql.Schema { 74 createTable, err := t.getCreateTable() 75 if err != nil { 76 panic(err) 77 } 78 return createTable.Schema() 79 } 80 81 // Collation implements the interface sql.Table. 82 func (t Table) Collation() sql.CollationID { 83 return sql.Collation_Default 84 } 85 86 // Pks implements sql.PrimaryKeyAlterableTable 87 func (t Table) Pks() []sql.IndexColumn { 88 createTable, err := t.getCreateTable() 89 if err != nil { 90 panic(err) 91 } 92 93 pkSch := createTable.PkSchema() 94 pkCols := make([]sql.IndexColumn, len(pkSch.PkOrdinals)) 95 for i, j := range pkSch.PkOrdinals { 96 col := pkSch.Schema[j] 97 pkCols[i] = sql.IndexColumn{Name: col.Name} 98 } 99 return pkCols 100 } 101 102 // PrimaryKeySchema implements sql.PrimaryKeyAlterableTable 103 func (t Table) PrimaryKeySchema() sql.PrimaryKeySchema { 104 createTable, err := t.getCreateTable() 105 if err != nil { 106 panic(err) 107 } 108 return createTable.PkSchema() 109 } 110 111 // Partitions implements the interface sql.Table. 112 func (t Table) Partitions(ctx *sql.Context) (sql.PartitionIter, error) { 113 return &tablePartitionIter{}, nil 114 } 115 116 // PartitionRows implements the interface sql.Table. 117 func (t Table) PartitionRows(ctx *sql.Context, partition sql.Partition) (sql.RowIter, error) { 118 return t.db.shim.Query(t.db.name, fmt.Sprintf("SELECT * FROM `%s`;", t.name)) 119 } 120 121 // Inserter implements the interface sql.InsertableTable. 122 func (t Table) Inserter(ctx *sql.Context) sql.RowInserter { 123 return &tableEditor{t, t.Schema()} 124 } 125 126 // Updater implements the interface sql.UpdatableTable. 127 func (t Table) Updater(ctx *sql.Context) sql.RowUpdater { 128 return &tableEditor{t, t.Schema()} 129 } 130 131 // Deleter implements the interface sql.DeletableTable. 132 func (t Table) Deleter(ctx *sql.Context) sql.RowDeleter { 133 return &tableEditor{t, t.Schema()} 134 } 135 136 // Replacer implements the interface sql.ReplaceableTable. 137 func (t Table) Replacer(ctx *sql.Context) sql.RowReplacer { 138 return &tableEditor{t, t.Schema()} 139 } 140 141 // Truncate implements the interface sql.TruncateableTable. 142 func (t Table) Truncate(ctx *sql.Context) (int, error) { 143 rows, err := t.db.shim.QueryRows(t.db.name, fmt.Sprintf("SELECT COUNT(*) FROM `%s`;", t.name)) 144 if err != nil { 145 return 0, err 146 } 147 rowCount, _, err := types.Int64.Convert(rows[0][0]) 148 if err != nil { 149 return 0, err 150 } 151 err = t.db.shim.Exec("", fmt.Sprintf("TRUNCATE TABLE `%s`;", t.name)) 152 return int(rowCount.(int64)), err 153 } 154 155 // AddColumn implements the interface sql.AlterableTable. 156 func (t Table) AddColumn(ctx *sql.Context, column *sql.Column, order *sql.ColumnOrder) error { 157 statement := fmt.Sprintf("ALTER TABLE `%s` ADD COLUMN `%s` %s;", t.name, column.Name, strings.ToUpper(column.Type.String())) 158 if !column.Nullable { 159 statement = fmt.Sprintf("%s NOT NULL", statement) 160 } 161 if column.AutoIncrement { 162 statement = fmt.Sprintf("%s AUTO_INCREMENT", statement) 163 } 164 if column.Default != nil { 165 statement = fmt.Sprintf("%s DEFAULT %s", statement, column.Default.String()) 166 } 167 if column.Comment != "" { 168 statement = fmt.Sprintf("%s COMMENT '%s'", statement, column.Comment) 169 } 170 if order != nil { 171 if order.First { 172 statement = fmt.Sprintf("%s FIRST", statement) 173 } else if len(order.AfterColumn) > 0 { 174 statement = fmt.Sprintf("%s AFTER `%s`", statement, order.AfterColumn) 175 } 176 } 177 return t.db.shim.Exec(t.db.name, statement) 178 } 179 180 // DropColumn implements the interface sql.AlterableTable. 181 func (t Table) DropColumn(ctx *sql.Context, columnName string) error { 182 return t.db.shim.Exec(t.db.name, fmt.Sprintf("ALTER TABLE `%s` DROP COLUMN `%s`;", t.name, columnName)) 183 } 184 185 // ModifyColumn implements the interface sql.AlterableTable. 186 func (t Table) ModifyColumn(ctx *sql.Context, columnName string, column *sql.Column, order *sql.ColumnOrder) error { 187 statement := fmt.Sprintf("ALTER TABLE `%s` CHANGE COLUMN `%s` `%s` %s;", t.name, columnName, column.Name, strings.ToUpper(column.Type.String())) 188 if !column.Nullable { 189 statement = fmt.Sprintf("%s NOT NULL", statement) 190 } 191 if column.AutoIncrement { 192 statement = fmt.Sprintf("%s AUTO_INCREMENT", statement) 193 } 194 if column.Default != nil { 195 statement = fmt.Sprintf("%s DEFAULT %s", statement, column.Default.String()) 196 } 197 if column.Comment != "" { 198 statement = fmt.Sprintf("%s COMMENT '%s'", statement, column.Comment) 199 } 200 if order != nil { 201 if order.First { 202 statement = fmt.Sprintf("%s FIRST", statement) 203 } else if len(order.AfterColumn) > 0 { 204 statement = fmt.Sprintf("%s AFTER `%s`", statement, order.AfterColumn) 205 } 206 } 207 return t.db.shim.Exec(t.db.name, statement) 208 } 209 210 // CreateIndex implements the interface sql.IndexAlterableTable. 211 func (t Table) CreateIndex(ctx *sql.Context, idx sql.IndexDef) error { 212 statement := "CREATE" 213 switch idx.Constraint { 214 case sql.IndexConstraint_Unique: 215 statement += " UNIQUE INDEX" 216 case sql.IndexConstraint_Fulltext: 217 statement += " FULLTEXT INDEX" 218 case sql.IndexConstraint_Spatial: 219 statement += " SPATIAL INDEX" 220 default: 221 statement += " INDEX" 222 } 223 idxColumnNames := make([]string, len(idx.Columns)) 224 for i, column := range idx.Columns { 225 idxColumnNames[i] = column.Name 226 } 227 if len(idx.Name) == 0 { 228 idx.Name = randString(10) 229 } 230 statement = fmt.Sprintf("%s `%s` ON `%s` (`%s`)", statement, idx.Name, t.name, strings.Join(idxColumnNames, "`,`")) 231 if len(idx.Comment) > 0 { 232 statement = fmt.Sprintf("%s COMMENT '%s'", statement, strings.ReplaceAll(idx.Comment, "'", `\'`)) 233 } 234 return t.db.shim.Exec(t.db.name, statement) 235 } 236 237 // DropIndex implements the interface sql.IndexAlterableTable. 238 func (t Table) DropIndex(ctx *sql.Context, indexName string) error { 239 return t.db.shim.Exec(t.db.name, fmt.Sprintf("ALTER TABLE `%s` DROP INDEX `%s`;", t.name, indexName)) 240 } 241 242 // RenameIndex implements the interface sql.IndexAlterableTable. 243 func (t Table) RenameIndex(ctx *sql.Context, fromIndexName string, toIndexName string) error { 244 return t.db.shim.Exec(t.db.name, fmt.Sprintf("ALTER TABLE `%s` RENAME INDEX `%s` TO `%s`;", t.name, fromIndexName, toIndexName)) 245 } 246 247 // GetIndexes implements the interface sql.IndexedTable. 248 func (t Table) GetIndexes(ctx *sql.Context) ([]sql.Index, error) { 249 //TODO: add this along with some kind of index implementation 250 return nil, nil 251 } 252 253 // GetDeclaredForeignKeys implements the interface sql.ForeignKeyTable. 254 func (t Table) GetDeclaredForeignKeys(ctx *sql.Context) ([]sql.ForeignKeyConstraint, error) { 255 //TODO: add this 256 return nil, nil 257 } 258 259 // GetReferencedForeignKeys implements the interface sql.ForeignKeyTable. 260 func (t Table) GetReferencedForeignKeys(ctx *sql.Context) ([]sql.ForeignKeyConstraint, error) { 261 //TODO: add this 262 return nil, nil 263 } 264 265 // AddForeignKey implements the interface sql.ForeignKeyTable. 266 func (t Table) AddForeignKey(ctx *sql.Context, fk sql.ForeignKeyConstraint) error { 267 constraint := "" 268 if len(fk.Name) > 0 { 269 constraint = fmt.Sprintf(" CONSTRAINT `%s`", fk.Name) 270 } 271 onDeleteStr := "" 272 if fk.OnDelete != sql.ForeignKeyReferentialAction_DefaultAction { 273 onDeleteStr = fmt.Sprintf(" ON DELETE %s", string(fk.OnDelete)) 274 } 275 onUpdateStr := "" 276 if fk.OnUpdate != sql.ForeignKeyReferentialAction_DefaultAction { 277 onUpdateStr = fmt.Sprintf(" ON UPDATE %s", string(fk.OnUpdate)) 278 } 279 return t.db.shim.Exec(t.db.name, fmt.Sprintf("ALTER TABLE `%s`.`%s` ADD%s FOREIGN KEY (`%s`) REFERENCES `%s`.`%s` (`%s`)%s%s;", 280 fk.Database, t.name, constraint, strings.Join(fk.Columns, "`,`"), fk.ParentDatabase, fk.ParentTable, 281 strings.Join(fk.ParentColumns, "`,`"), onDeleteStr, onUpdateStr)) 282 } 283 284 // DropForeignKey implements the interface sql.ForeignKeyTable. 285 func (t Table) DropForeignKey(ctx *sql.Context, fkName string) error { 286 return t.db.shim.Exec(t.db.name, fmt.Sprintf("ALTER TABLE `%s` DROP FOREIGN KEY `%s`;", t.name, fkName)) 287 } 288 289 // UpdateForeignKey implements the interface sql.ForeignKeyTable. 290 func (t Table) UpdateForeignKey(ctx *sql.Context, fkName string, fkDef sql.ForeignKeyConstraint) error { 291 // Will automatically be handled by MySQL 292 return nil 293 } 294 295 // CreateIndexForForeignKey implements the interface sql.ForeignKeyTable. 296 func (t Table) CreateIndexForForeignKey(ctx *sql.Context, idx sql.IndexDef) error { 297 return nil 298 } 299 300 // SetForeignKeyResolved implements the interface sql.ForeignKeyTable. 301 func (t Table) SetForeignKeyResolved(ctx *sql.Context, fkName string) error { 302 return nil 303 } 304 305 // GetForeignKeyEditor implements the interface sql.ForeignKeyTable. 306 func (t Table) GetForeignKeyEditor(ctx *sql.Context) sql.ForeignKeyEditor { 307 return &tableEditor{t, t.Schema()} 308 } 309 310 // CreateCheck implements the interface sql.CheckAlterableTable. 311 func (t Table) CreateCheck(ctx *sql.Context, check *sql.CheckDefinition) error { 312 statement := fmt.Sprintf("ALTER TABLE `%s` ADD", t.name) 313 if len(check.Name) > 0 { 314 statement = fmt.Sprintf("%s CONSTRAINT `%s`", statement, check.Name) 315 } 316 statement = fmt.Sprintf("%s CHECK (%s)", statement, check.CheckExpression) 317 if !check.Enforced { 318 statement = fmt.Sprintf("%s NOT ENFORCED", statement) 319 } 320 return t.db.shim.Exec(t.db.name, statement) 321 } 322 323 // DropCheck implements the interface sql.CheckAlterableTable. 324 func (t Table) DropCheck(ctx *sql.Context, chName string) error { 325 return t.db.shim.Exec(t.db.name, fmt.Sprintf("ALTER TABLE `%s` DROP CHECK `%s`;", t.name, chName)) 326 } 327 328 // GetChecks implements the interface sql.CheckTable. 329 func (t Table) GetChecks(ctx *sql.Context) ([]sql.CheckDefinition, error) { 330 //TODO: add this 331 return nil, nil 332 } 333 334 // Close implements the interface sql.AutoIncrementSetter. 335 func (t Table) Close(ctx *sql.Context) error { 336 return nil 337 } 338 339 // DataLength implements the interface sql.StatisticsTable. 340 func (t Table) DataLength(ctx *sql.Context) (uint64, error) { 341 // SELECT * FROM information_schema.TABLES WHERE (TABLE_SCHEMA = 'sys') AND (TABLE_NAME = 'test'); 342 rows, err := t.db.shim.QueryRows(t.db.name, fmt.Sprintf("SELECT COUNT(*) FROM `%s`;", t.name)) 343 if err != nil { 344 return 0, err 345 } 346 rowCount, _, err := types.Uint64.Convert(rows[0][0]) 347 if err != nil { 348 return 0, err 349 } 350 return rowCount.(uint64), nil 351 } 352 353 // Cardinality implements the interface sql.StatisticsTable. 354 func (t Table) RowCount(ctx *sql.Context) (uint64, bool, error) { 355 return 0, false, nil 356 } 357 358 // CreatePrimaryKey implements the interface sql.PrimaryKeyAlterableTable. 359 func (t Table) CreatePrimaryKey(ctx *sql.Context, columns []sql.IndexColumn) error { 360 pkNames := make([]string, len(columns)) 361 for i, column := range columns { 362 pkNames[i] = column.Name 363 } 364 return t.db.shim.Exec(t.db.name, fmt.Sprintf("ALTER TABLE `%s` ADD PRIMARY KEY (`%s`);", t.name, strings.Join(pkNames, "`,`"))) 365 } 366 367 // DropPrimaryKey implements the interface sql.PrimaryKeyAlterableTable. 368 func (t Table) DropPrimaryKey(ctx *sql.Context) error { 369 return t.db.shim.Exec(t.db.name, fmt.Sprintf("ALTER TABLE `%s` DROP PRIMARY KEY;", t.name)) 370 } 371 372 // getCreateTable returns this table as a CreateTable node. 373 func (t Table) getCreateTable() (*plan.CreateTable, error) { 374 rows, err := t.db.shim.QueryRows(t.db.name, fmt.Sprintf("SHOW CREATE TABLE `%s`;", t.name)) 375 if err != nil { 376 return nil, err 377 } 378 if len(rows) == 0 || len(rows[0]) == 0 { 379 return nil, sql.ErrTableNotFound.New(t.name) 380 } 381 // TODO add catalog 382 createTableNode, err := planbuilder.Parse(sql.NewEmptyContext(), sql.MapCatalog{Tables: map[string]sql.Table{t.name: t}}, rows[0][1].(string)) 383 if err != nil { 384 return nil, err 385 } 386 return createTableNode.(*plan.CreateTable), nil 387 } 388 389 // randString returns a random string of the given length. 390 // Retrieved from https://stackoverflow.com/questions/22892120/how-to-generate-a-random-string-of-a-fixed-length-in-go 391 func randString(n int) string { 392 const letterIdxBits = 6 393 const letterIdxMask = 1<<letterIdxBits - 1 394 const letterIdxMax = 63 / letterIdxBits 395 const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" 396 b := make([]byte, n) 397 // A rand.Int63() generates 63 random bits, enough for letterIdxMax letters! 398 for i, cache, remain := n-1, rand.Int63(), letterIdxMax; i >= 0; { 399 if remain == 0 { 400 cache, remain = rand.Int63(), letterIdxMax 401 } 402 if idx := int(cache & letterIdxMask); idx < len(letterBytes) { 403 b[i] = letterBytes[idx] 404 i-- 405 } 406 cache >>= letterIdxBits 407 remain-- 408 } 409 410 return string(b) 411 }