github.com/dolthub/go-mysql-server@v0.18.0/enginetest/mysqlshim/table.go (about)

     1  // Copyright 2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package mysqlshim
    16  
    17  import (
    18  	"fmt"
    19  	"math/rand"
    20  	"strings"
    21  
    22  	"github.com/dolthub/go-mysql-server/sql/planbuilder"
    23  
    24  	"github.com/dolthub/go-mysql-server/sql"
    25  	"github.com/dolthub/go-mysql-server/sql/plan"
    26  	"github.com/dolthub/go-mysql-server/sql/types"
    27  )
    28  
    29  // Table represents a table for a local MySQL server.
    30  type Table struct {
    31  	db   Database
    32  	name string
    33  }
    34  
    35  var _ sql.Table = Table{}
    36  var _ sql.InsertableTable = Table{}
    37  var _ sql.UpdatableTable = Table{}
    38  var _ sql.DeletableTable = Table{}
    39  var _ sql.ReplaceableTable = Table{}
    40  var _ sql.TruncateableTable = Table{}
    41  var _ sql.IndexAddressableTable = Table{}
    42  var _ sql.AlterableTable = Table{}
    43  var _ sql.IndexAlterableTable = Table{}
    44  var _ sql.ForeignKeyTable = Table{}
    45  var _ sql.CheckAlterableTable = Table{}
    46  var _ sql.CheckTable = Table{}
    47  var _ sql.StatisticsTable = Table{}
    48  var _ sql.PrimaryKeyAlterableTable = Table{}
    49  
    50  func (t Table) IndexedAccess(sql.IndexLookup) sql.IndexedTable {
    51  	panic("not implemented")
    52  }
    53  
    54  func (t Table) PreciseMatch() bool {
    55  	return true
    56  }
    57  
    58  func (t Table) IndexedPartitions(ctx *sql.Context, _ sql.IndexLookup) (sql.PartitionIter, error) {
    59  	return t.Partitions(ctx)
    60  }
    61  
    62  // Name implements the interface sql.Table.
    63  func (t Table) Name() string {
    64  	return t.name
    65  }
    66  
    67  // String implements the interface sql.Table.
    68  func (t Table) String() string {
    69  	return t.name
    70  }
    71  
    72  // Schema implements the interface sql.Table.
    73  func (t Table) Schema() sql.Schema {
    74  	createTable, err := t.getCreateTable()
    75  	if err != nil {
    76  		panic(err)
    77  	}
    78  	return createTable.Schema()
    79  }
    80  
    81  // Collation implements the interface sql.Table.
    82  func (t Table) Collation() sql.CollationID {
    83  	return sql.Collation_Default
    84  }
    85  
    86  // Pks implements sql.PrimaryKeyAlterableTable
    87  func (t Table) Pks() []sql.IndexColumn {
    88  	createTable, err := t.getCreateTable()
    89  	if err != nil {
    90  		panic(err)
    91  	}
    92  
    93  	pkSch := createTable.PkSchema()
    94  	pkCols := make([]sql.IndexColumn, len(pkSch.PkOrdinals))
    95  	for i, j := range pkSch.PkOrdinals {
    96  		col := pkSch.Schema[j]
    97  		pkCols[i] = sql.IndexColumn{Name: col.Name}
    98  	}
    99  	return pkCols
   100  }
   101  
   102  // PrimaryKeySchema implements sql.PrimaryKeyAlterableTable
   103  func (t Table) PrimaryKeySchema() sql.PrimaryKeySchema {
   104  	createTable, err := t.getCreateTable()
   105  	if err != nil {
   106  		panic(err)
   107  	}
   108  	return createTable.PkSchema()
   109  }
   110  
   111  // Partitions implements the interface sql.Table.
   112  func (t Table) Partitions(ctx *sql.Context) (sql.PartitionIter, error) {
   113  	return &tablePartitionIter{}, nil
   114  }
   115  
   116  // PartitionRows implements the interface sql.Table.
   117  func (t Table) PartitionRows(ctx *sql.Context, partition sql.Partition) (sql.RowIter, error) {
   118  	return t.db.shim.Query(t.db.name, fmt.Sprintf("SELECT * FROM `%s`;", t.name))
   119  }
   120  
   121  // Inserter implements the interface sql.InsertableTable.
   122  func (t Table) Inserter(ctx *sql.Context) sql.RowInserter {
   123  	return &tableEditor{t, t.Schema()}
   124  }
   125  
   126  // Updater implements the interface sql.UpdatableTable.
   127  func (t Table) Updater(ctx *sql.Context) sql.RowUpdater {
   128  	return &tableEditor{t, t.Schema()}
   129  }
   130  
   131  // Deleter implements the interface sql.DeletableTable.
   132  func (t Table) Deleter(ctx *sql.Context) sql.RowDeleter {
   133  	return &tableEditor{t, t.Schema()}
   134  }
   135  
   136  // Replacer implements the interface sql.ReplaceableTable.
   137  func (t Table) Replacer(ctx *sql.Context) sql.RowReplacer {
   138  	return &tableEditor{t, t.Schema()}
   139  }
   140  
   141  // Truncate implements the interface sql.TruncateableTable.
   142  func (t Table) Truncate(ctx *sql.Context) (int, error) {
   143  	rows, err := t.db.shim.QueryRows(t.db.name, fmt.Sprintf("SELECT COUNT(*) FROM `%s`;", t.name))
   144  	if err != nil {
   145  		return 0, err
   146  	}
   147  	rowCount, _, err := types.Int64.Convert(rows[0][0])
   148  	if err != nil {
   149  		return 0, err
   150  	}
   151  	err = t.db.shim.Exec("", fmt.Sprintf("TRUNCATE TABLE `%s`;", t.name))
   152  	return int(rowCount.(int64)), err
   153  }
   154  
   155  // AddColumn implements the interface sql.AlterableTable.
   156  func (t Table) AddColumn(ctx *sql.Context, column *sql.Column, order *sql.ColumnOrder) error {
   157  	statement := fmt.Sprintf("ALTER TABLE `%s` ADD COLUMN `%s` %s;", t.name, column.Name, strings.ToUpper(column.Type.String()))
   158  	if !column.Nullable {
   159  		statement = fmt.Sprintf("%s NOT NULL", statement)
   160  	}
   161  	if column.AutoIncrement {
   162  		statement = fmt.Sprintf("%s AUTO_INCREMENT", statement)
   163  	}
   164  	if column.Default != nil {
   165  		statement = fmt.Sprintf("%s DEFAULT %s", statement, column.Default.String())
   166  	}
   167  	if column.Comment != "" {
   168  		statement = fmt.Sprintf("%s COMMENT '%s'", statement, column.Comment)
   169  	}
   170  	if order != nil {
   171  		if order.First {
   172  			statement = fmt.Sprintf("%s FIRST", statement)
   173  		} else if len(order.AfterColumn) > 0 {
   174  			statement = fmt.Sprintf("%s AFTER `%s`", statement, order.AfterColumn)
   175  		}
   176  	}
   177  	return t.db.shim.Exec(t.db.name, statement)
   178  }
   179  
   180  // DropColumn implements the interface sql.AlterableTable.
   181  func (t Table) DropColumn(ctx *sql.Context, columnName string) error {
   182  	return t.db.shim.Exec(t.db.name, fmt.Sprintf("ALTER TABLE `%s` DROP COLUMN `%s`;", t.name, columnName))
   183  }
   184  
   185  // ModifyColumn implements the interface sql.AlterableTable.
   186  func (t Table) ModifyColumn(ctx *sql.Context, columnName string, column *sql.Column, order *sql.ColumnOrder) error {
   187  	statement := fmt.Sprintf("ALTER TABLE `%s` CHANGE COLUMN `%s` `%s` %s;", t.name, columnName, column.Name, strings.ToUpper(column.Type.String()))
   188  	if !column.Nullable {
   189  		statement = fmt.Sprintf("%s NOT NULL", statement)
   190  	}
   191  	if column.AutoIncrement {
   192  		statement = fmt.Sprintf("%s AUTO_INCREMENT", statement)
   193  	}
   194  	if column.Default != nil {
   195  		statement = fmt.Sprintf("%s DEFAULT %s", statement, column.Default.String())
   196  	}
   197  	if column.Comment != "" {
   198  		statement = fmt.Sprintf("%s COMMENT '%s'", statement, column.Comment)
   199  	}
   200  	if order != nil {
   201  		if order.First {
   202  			statement = fmt.Sprintf("%s FIRST", statement)
   203  		} else if len(order.AfterColumn) > 0 {
   204  			statement = fmt.Sprintf("%s AFTER `%s`", statement, order.AfterColumn)
   205  		}
   206  	}
   207  	return t.db.shim.Exec(t.db.name, statement)
   208  }
   209  
   210  // CreateIndex implements the interface sql.IndexAlterableTable.
   211  func (t Table) CreateIndex(ctx *sql.Context, idx sql.IndexDef) error {
   212  	statement := "CREATE"
   213  	switch idx.Constraint {
   214  	case sql.IndexConstraint_Unique:
   215  		statement += " UNIQUE INDEX"
   216  	case sql.IndexConstraint_Fulltext:
   217  		statement += " FULLTEXT INDEX"
   218  	case sql.IndexConstraint_Spatial:
   219  		statement += " SPATIAL INDEX"
   220  	default:
   221  		statement += " INDEX"
   222  	}
   223  	idxColumnNames := make([]string, len(idx.Columns))
   224  	for i, column := range idx.Columns {
   225  		idxColumnNames[i] = column.Name
   226  	}
   227  	if len(idx.Name) == 0 {
   228  		idx.Name = randString(10)
   229  	}
   230  	statement = fmt.Sprintf("%s `%s` ON `%s` (`%s`)", statement, idx.Name, t.name, strings.Join(idxColumnNames, "`,`"))
   231  	if len(idx.Comment) > 0 {
   232  		statement = fmt.Sprintf("%s COMMENT '%s'", statement, strings.ReplaceAll(idx.Comment, "'", `\'`))
   233  	}
   234  	return t.db.shim.Exec(t.db.name, statement)
   235  }
   236  
   237  // DropIndex implements the interface sql.IndexAlterableTable.
   238  func (t Table) DropIndex(ctx *sql.Context, indexName string) error {
   239  	return t.db.shim.Exec(t.db.name, fmt.Sprintf("ALTER TABLE `%s` DROP INDEX `%s`;", t.name, indexName))
   240  }
   241  
   242  // RenameIndex implements the interface sql.IndexAlterableTable.
   243  func (t Table) RenameIndex(ctx *sql.Context, fromIndexName string, toIndexName string) error {
   244  	return t.db.shim.Exec(t.db.name, fmt.Sprintf("ALTER TABLE `%s` RENAME INDEX `%s` TO `%s`;", t.name, fromIndexName, toIndexName))
   245  }
   246  
   247  // GetIndexes implements the interface sql.IndexedTable.
   248  func (t Table) GetIndexes(ctx *sql.Context) ([]sql.Index, error) {
   249  	//TODO: add this along with some kind of index implementation
   250  	return nil, nil
   251  }
   252  
   253  // GetDeclaredForeignKeys implements the interface sql.ForeignKeyTable.
   254  func (t Table) GetDeclaredForeignKeys(ctx *sql.Context) ([]sql.ForeignKeyConstraint, error) {
   255  	//TODO: add this
   256  	return nil, nil
   257  }
   258  
   259  // GetReferencedForeignKeys implements the interface sql.ForeignKeyTable.
   260  func (t Table) GetReferencedForeignKeys(ctx *sql.Context) ([]sql.ForeignKeyConstraint, error) {
   261  	//TODO: add this
   262  	return nil, nil
   263  }
   264  
   265  // AddForeignKey implements the interface sql.ForeignKeyTable.
   266  func (t Table) AddForeignKey(ctx *sql.Context, fk sql.ForeignKeyConstraint) error {
   267  	constraint := ""
   268  	if len(fk.Name) > 0 {
   269  		constraint = fmt.Sprintf(" CONSTRAINT `%s`", fk.Name)
   270  	}
   271  	onDeleteStr := ""
   272  	if fk.OnDelete != sql.ForeignKeyReferentialAction_DefaultAction {
   273  		onDeleteStr = fmt.Sprintf(" ON DELETE %s", string(fk.OnDelete))
   274  	}
   275  	onUpdateStr := ""
   276  	if fk.OnUpdate != sql.ForeignKeyReferentialAction_DefaultAction {
   277  		onUpdateStr = fmt.Sprintf(" ON UPDATE %s", string(fk.OnUpdate))
   278  	}
   279  	return t.db.shim.Exec(t.db.name, fmt.Sprintf("ALTER TABLE `%s`.`%s` ADD%s FOREIGN KEY (`%s`) REFERENCES `%s`.`%s` (`%s`)%s%s;",
   280  		fk.Database, t.name, constraint, strings.Join(fk.Columns, "`,`"), fk.ParentDatabase, fk.ParentTable,
   281  		strings.Join(fk.ParentColumns, "`,`"), onDeleteStr, onUpdateStr))
   282  }
   283  
   284  // DropForeignKey implements the interface sql.ForeignKeyTable.
   285  func (t Table) DropForeignKey(ctx *sql.Context, fkName string) error {
   286  	return t.db.shim.Exec(t.db.name, fmt.Sprintf("ALTER TABLE `%s` DROP FOREIGN KEY `%s`;", t.name, fkName))
   287  }
   288  
   289  // UpdateForeignKey implements the interface sql.ForeignKeyTable.
   290  func (t Table) UpdateForeignKey(ctx *sql.Context, fkName string, fkDef sql.ForeignKeyConstraint) error {
   291  	// Will automatically be handled by MySQL
   292  	return nil
   293  }
   294  
   295  // CreateIndexForForeignKey implements the interface sql.ForeignKeyTable.
   296  func (t Table) CreateIndexForForeignKey(ctx *sql.Context, idx sql.IndexDef) error {
   297  	return nil
   298  }
   299  
   300  // SetForeignKeyResolved implements the interface sql.ForeignKeyTable.
   301  func (t Table) SetForeignKeyResolved(ctx *sql.Context, fkName string) error {
   302  	return nil
   303  }
   304  
   305  // GetForeignKeyEditor implements the interface sql.ForeignKeyTable.
   306  func (t Table) GetForeignKeyEditor(ctx *sql.Context) sql.ForeignKeyEditor {
   307  	return &tableEditor{t, t.Schema()}
   308  }
   309  
   310  // CreateCheck implements the interface sql.CheckAlterableTable.
   311  func (t Table) CreateCheck(ctx *sql.Context, check *sql.CheckDefinition) error {
   312  	statement := fmt.Sprintf("ALTER TABLE `%s` ADD", t.name)
   313  	if len(check.Name) > 0 {
   314  		statement = fmt.Sprintf("%s CONSTRAINT `%s`", statement, check.Name)
   315  	}
   316  	statement = fmt.Sprintf("%s CHECK (%s)", statement, check.CheckExpression)
   317  	if !check.Enforced {
   318  		statement = fmt.Sprintf("%s NOT ENFORCED", statement)
   319  	}
   320  	return t.db.shim.Exec(t.db.name, statement)
   321  }
   322  
   323  // DropCheck implements the interface sql.CheckAlterableTable.
   324  func (t Table) DropCheck(ctx *sql.Context, chName string) error {
   325  	return t.db.shim.Exec(t.db.name, fmt.Sprintf("ALTER TABLE `%s` DROP CHECK `%s`;", t.name, chName))
   326  }
   327  
   328  // GetChecks implements the interface sql.CheckTable.
   329  func (t Table) GetChecks(ctx *sql.Context) ([]sql.CheckDefinition, error) {
   330  	//TODO: add this
   331  	return nil, nil
   332  }
   333  
   334  // Close implements the interface sql.AutoIncrementSetter.
   335  func (t Table) Close(ctx *sql.Context) error {
   336  	return nil
   337  }
   338  
   339  // DataLength implements the interface sql.StatisticsTable.
   340  func (t Table) DataLength(ctx *sql.Context) (uint64, error) {
   341  	// SELECT * FROM information_schema.TABLES WHERE (TABLE_SCHEMA = 'sys') AND (TABLE_NAME = 'test');
   342  	rows, err := t.db.shim.QueryRows(t.db.name, fmt.Sprintf("SELECT COUNT(*) FROM `%s`;", t.name))
   343  	if err != nil {
   344  		return 0, err
   345  	}
   346  	rowCount, _, err := types.Uint64.Convert(rows[0][0])
   347  	if err != nil {
   348  		return 0, err
   349  	}
   350  	return rowCount.(uint64), nil
   351  }
   352  
   353  // Cardinality implements the interface sql.StatisticsTable.
   354  func (t Table) RowCount(ctx *sql.Context) (uint64, bool, error) {
   355  	return 0, false, nil
   356  }
   357  
   358  // CreatePrimaryKey implements the interface sql.PrimaryKeyAlterableTable.
   359  func (t Table) CreatePrimaryKey(ctx *sql.Context, columns []sql.IndexColumn) error {
   360  	pkNames := make([]string, len(columns))
   361  	for i, column := range columns {
   362  		pkNames[i] = column.Name
   363  	}
   364  	return t.db.shim.Exec(t.db.name, fmt.Sprintf("ALTER TABLE `%s` ADD PRIMARY KEY (`%s`);", t.name, strings.Join(pkNames, "`,`")))
   365  }
   366  
   367  // DropPrimaryKey implements the interface sql.PrimaryKeyAlterableTable.
   368  func (t Table) DropPrimaryKey(ctx *sql.Context) error {
   369  	return t.db.shim.Exec(t.db.name, fmt.Sprintf("ALTER TABLE `%s` DROP PRIMARY KEY;", t.name))
   370  }
   371  
   372  // getCreateTable returns this table as a CreateTable node.
   373  func (t Table) getCreateTable() (*plan.CreateTable, error) {
   374  	rows, err := t.db.shim.QueryRows(t.db.name, fmt.Sprintf("SHOW CREATE TABLE `%s`;", t.name))
   375  	if err != nil {
   376  		return nil, err
   377  	}
   378  	if len(rows) == 0 || len(rows[0]) == 0 {
   379  		return nil, sql.ErrTableNotFound.New(t.name)
   380  	}
   381  	// TODO add catalog
   382  	createTableNode, err := planbuilder.Parse(sql.NewEmptyContext(), sql.MapCatalog{Tables: map[string]sql.Table{t.name: t}}, rows[0][1].(string))
   383  	if err != nil {
   384  		return nil, err
   385  	}
   386  	return createTableNode.(*plan.CreateTable), nil
   387  }
   388  
   389  // randString returns a random string of the given length.
   390  // Retrieved from https://stackoverflow.com/questions/22892120/how-to-generate-a-random-string-of-a-fixed-length-in-go
   391  func randString(n int) string {
   392  	const letterIdxBits = 6
   393  	const letterIdxMask = 1<<letterIdxBits - 1
   394  	const letterIdxMax = 63 / letterIdxBits
   395  	const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
   396  	b := make([]byte, n)
   397  	// A rand.Int63() generates 63 random bits, enough for letterIdxMax letters!
   398  	for i, cache, remain := n-1, rand.Int63(), letterIdxMax; i >= 0; {
   399  		if remain == 0 {
   400  			cache, remain = rand.Int63(), letterIdxMax
   401  		}
   402  		if idx := int(cache & letterIdxMask); idx < len(letterBytes) {
   403  			b[i] = letterBytes[idx]
   404  			i--
   405  		}
   406  		cache >>= letterIdxBits
   407  		remain--
   408  	}
   409  
   410  	return string(b)
   411  }