github.com/dolthub/go-mysql-server@v0.18.0/enginetest/mysqlshim/database.go (about)

     1  // Copyright 2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package mysqlshim
    16  
    17  import (
    18  	"fmt"
    19  	"io"
    20  	"strings"
    21  	"time"
    22  
    23  	"github.com/dolthub/go-mysql-server/sql"
    24  )
    25  
    26  // Database represents a database for a local MySQL server.
    27  type Database struct {
    28  	shim *MySQLShim
    29  	name string
    30  }
    31  
    32  var _ sql.Database = Database{}
    33  var _ sql.TableCreator = Database{}
    34  var _ sql.TableDropper = Database{}
    35  var _ sql.TableRenamer = Database{}
    36  var _ sql.TriggerDatabase = Database{}
    37  var _ sql.StoredProcedureDatabase = Database{}
    38  var _ sql.ViewDatabase = Database{}
    39  var _ sql.EventDatabase = Database{}
    40  
    41  // Name implements the interface sql.Database.
    42  func (d Database) Name() string {
    43  	return d.name
    44  }
    45  
    46  // GetTableInsensitive implements the interface sql.Database.
    47  func (d Database) GetTableInsensitive(ctx *sql.Context, tblName string) (sql.Table, bool, error) {
    48  	tables, err := d.GetTableNames(ctx)
    49  	if err != nil {
    50  		return nil, false, err
    51  	}
    52  	lowerName := strings.ToLower(tblName)
    53  	for _, readName := range tables {
    54  		if lowerName == strings.ToLower(readName) {
    55  			return Table{d, readName}, true, nil
    56  		}
    57  	}
    58  	return nil, false, nil
    59  }
    60  
    61  // GetTableNames implements the interface sql.Database.
    62  func (d Database) GetTableNames(ctx *sql.Context) ([]string, error) {
    63  	rows, err := d.shim.Query(d.name, "SHOW TABLES;")
    64  	if err != nil {
    65  		return nil, err
    66  	}
    67  	defer rows.Close(ctx)
    68  	var tableNames []string
    69  	var row sql.Row
    70  	for row, err = rows.Next(ctx); err == nil; row, err = rows.Next(ctx) {
    71  		tableNames = append(tableNames, row[0].(string))
    72  	}
    73  	if err != io.EOF {
    74  		return nil, err
    75  	}
    76  	return tableNames, nil
    77  }
    78  
    79  // CreateTable implements the interface sql.TableCreator.
    80  func (d Database) CreateTable(ctx *sql.Context, name string, schema sql.PrimaryKeySchema, collation sql.CollationID, comment string) error {
    81  	colStmts := make([]string, len(schema.Schema))
    82  	var primaryKeyCols []string
    83  	for i, col := range schema.Schema {
    84  		stmt := fmt.Sprintf("  `%s` %s", col.Name, col.Type.String())
    85  		if !col.Nullable {
    86  			stmt = fmt.Sprintf("%s NOT NULL", stmt)
    87  		}
    88  		if col.AutoIncrement {
    89  			stmt = fmt.Sprintf("%s AUTO_INCREMENT", stmt)
    90  		}
    91  		if col.Default != nil {
    92  			stmt = fmt.Sprintf("%s DEFAULT %s", stmt, col.Default.String())
    93  		}
    94  		if col.Comment != "" {
    95  			stmt = fmt.Sprintf("%s COMMENT '%s'", stmt, col.Comment)
    96  		}
    97  		if col.PrimaryKey {
    98  			primaryKeyCols = append(primaryKeyCols, col.Name)
    99  		}
   100  		colStmts[i] = stmt
   101  	}
   102  	if len(primaryKeyCols) > 0 {
   103  		primaryKey := fmt.Sprintf("  PRIMARY KEY (`%s`)", strings.Join(primaryKeyCols, "`,`"))
   104  		colStmts = append(colStmts, primaryKey)
   105  	}
   106  	return d.shim.Exec(d.name, fmt.Sprintf("CREATE TABLE `%s` (\n%s\n) ENGINE=InnoDB DEFAULT COLLATE=%s COMMENT='%s';",
   107  		name, strings.Join(colStmts, ",\n"), sql.Collation_Default.String(), comment))
   108  }
   109  
   110  // DropTable implements the interface sql.TableDropper.
   111  func (d Database) DropTable(ctx *sql.Context, name string) error {
   112  	return d.shim.Exec(d.name, fmt.Sprintf("DROP TABLE `%s`;", name))
   113  }
   114  
   115  // RenameTable implements the interface sql.TableRenamer.
   116  func (d Database) RenameTable(ctx *sql.Context, oldName, newName string) error {
   117  	return d.shim.Exec(d.name, fmt.Sprintf("RENAME TABLE `%s` TO `%s`;", oldName, newName))
   118  }
   119  
   120  // GetTriggers implements the interface sql.TriggerDatabase.
   121  func (d Database) GetTriggers(ctx *sql.Context) ([]sql.TriggerDefinition, error) {
   122  	rows, err := d.shim.Query(d.name, "SHOW TRIGGERS;")
   123  	if err != nil {
   124  		return nil, err
   125  	}
   126  	defer rows.Close(ctx)
   127  	var triggers []sql.TriggerDefinition
   128  	var row sql.Row
   129  	for row, err = rows.Next(ctx); err == nil; row, err = rows.Next(ctx) {
   130  		// Trigger, Event, Table, Statement, Timing, Created, sql_mode, ...
   131  		triggers = append(triggers, sql.TriggerDefinition{
   132  			Name: row[0].(string),
   133  			CreateStatement: fmt.Sprintf("CREATE TRIGGER `%s` %s %s ON `%s` FOR EACH ROW %s;",
   134  				row[0].(string), row[4].(string), row[1].(string), row[2].(string), row[3].(string)),
   135  			CreatedAt: time.Time{}, // TODO: time works in with doltharness
   136  		})
   137  	}
   138  	if err != io.EOF {
   139  		return nil, err
   140  	}
   141  	return triggers, nil
   142  }
   143  
   144  // CreateTrigger implements the interface sql.TriggerDatabase.
   145  func (d Database) CreateTrigger(ctx *sql.Context, definition sql.TriggerDefinition) error {
   146  	return d.shim.Exec(d.name, definition.CreateStatement)
   147  }
   148  
   149  // DropTrigger implements the interface sql.TriggerDatabase.
   150  func (d Database) DropTrigger(ctx *sql.Context, name string) error {
   151  	return d.shim.Exec(d.name, fmt.Sprintf("DROP TRIGGER `%s`;", name))
   152  }
   153  
   154  // GetStoredProcedure implements the interface sql.StoredProcedureDatabase.
   155  func (d Database) GetStoredProcedure(ctx *sql.Context, name string) (sql.StoredProcedureDetails, bool, error) {
   156  	name = strings.ToLower(name)
   157  	procedures, err := d.GetStoredProcedures(ctx)
   158  	if err != nil {
   159  		return sql.StoredProcedureDetails{}, false, err
   160  	}
   161  	for _, procedure := range procedures {
   162  		if name == strings.ToLower(procedure.Name) {
   163  			return procedure, true, nil
   164  		}
   165  	}
   166  	return sql.StoredProcedureDetails{}, false, nil
   167  }
   168  
   169  // GetStoredProcedures implements the interface sql.StoredProcedureDatabase.
   170  func (d Database) GetStoredProcedures(ctx *sql.Context) ([]sql.StoredProcedureDetails, error) {
   171  	procedures, err := d.shim.QueryRows("", fmt.Sprintf("SHOW PROCEDURE STATUS WHERE Db = '%s';", d.name))
   172  	if err != nil {
   173  		return nil, err
   174  	}
   175  	storedProcedureDetails := make([]sql.StoredProcedureDetails, len(procedures))
   176  	for i, procedure := range procedures {
   177  		// Db, Name, Type, Definer, Modified, Created, Security_type, Comment, ...
   178  		procedureStatement, err := d.shim.QueryRows("", fmt.Sprintf("SHOW CREATE PROCEDURE `%s`.`%s`;", d.name, procedure[1]))
   179  		if err != nil {
   180  			return nil, err
   181  		}
   182  		// Procedure, sql_mode, Create Procedure, ...
   183  		storedProcedureDetails[i] = sql.StoredProcedureDetails{
   184  			Name:            procedureStatement[0][0].(string),
   185  			CreateStatement: procedureStatement[0][2].(string),
   186  			CreatedAt:       time.Time{}, // these should be added someday
   187  			ModifiedAt:      time.Time{},
   188  		}
   189  	}
   190  	return storedProcedureDetails, nil
   191  }
   192  
   193  // SaveStoredProcedure implements the interface sql.StoredProcedureDatabase.
   194  func (d Database) SaveStoredProcedure(ctx *sql.Context, spd sql.StoredProcedureDetails) error {
   195  	return d.shim.Exec(d.name, spd.CreateStatement)
   196  }
   197  
   198  // DropStoredProcedure implements the interface sql.StoredProcedureDatabase.
   199  func (d Database) DropStoredProcedure(ctx *sql.Context, name string) error {
   200  	return d.shim.Exec(d.name, fmt.Sprintf("DROP PROCEDURE `%s`;", name))
   201  }
   202  
   203  // GetEvent implements sql.EventDatabase
   204  func (d Database) GetEvent(ctx *sql.Context, name string) (sql.EventDefinition, bool, error) {
   205  	name = strings.ToLower(name)
   206  	events, _, err := d.GetEvents(ctx)
   207  	if err != nil {
   208  		return sql.EventDefinition{}, false, err
   209  	}
   210  	for _, event := range events {
   211  		if name == strings.ToLower(event.Name) {
   212  			return event, true, nil
   213  		}
   214  	}
   215  	return sql.EventDefinition{}, false, nil
   216  }
   217  
   218  // GetEvents implements sql.EventDatabase
   219  func (d Database) GetEvents(_ *sql.Context) ([]sql.EventDefinition, interface{}, error) {
   220  	events, err := d.shim.QueryRows("", fmt.Sprintf("SHOW EVENTS WHERE Db = '%s';", d.name))
   221  	if err != nil {
   222  		return nil, nil, err
   223  	}
   224  	eventDefinition := make([]sql.EventDefinition, len(events))
   225  	for i, event := range events {
   226  		// Db, Name, Definer, Time Zone, Type, ...
   227  		eventStmt, err := d.shim.QueryRows("", fmt.Sprintf("SHOW CREATE EVENT `%s`.`%s`;", d.name, event[1]))
   228  		if err != nil {
   229  			return nil, nil, err
   230  		}
   231  		// Event, sql_mode, time_zone, Create Event, ...
   232  		eventDefinition[i] = sql.EventDefinition{
   233  			Name: eventStmt[0][0].(string),
   234  			// TODO: other fields should be added such as Created, LastAltered
   235  		}
   236  	}
   237  	// MySQL shim doesn't support event reloading, so token is always nil
   238  	return eventDefinition, nil, nil
   239  }
   240  
   241  // SaveEvent implements sql.EventDatabase
   242  func (d Database) SaveEvent(ctx *sql.Context, event sql.EventDefinition) (bool, error) {
   243  	return event.Status == sql.EventStatus_Enable.String(), d.shim.Exec(d.name, event.CreateEventStatement())
   244  }
   245  
   246  // DropEvent implements sql.EventDatabase
   247  func (d Database) DropEvent(ctx *sql.Context, name string) error {
   248  	return d.shim.Exec(d.name, fmt.Sprintf("DROP EVENT `%s`;", name))
   249  }
   250  
   251  // UpdateEvent implements sql.EventDatabase
   252  func (d Database) UpdateEvent(_ *sql.Context, originalName string, event sql.EventDefinition) (bool, error) {
   253  	err := d.shim.Exec(d.name, fmt.Sprintf("DROP EVENT `%s`;", originalName))
   254  	if err != nil {
   255  		return false, err
   256  	}
   257  	return event.Status == sql.EventStatus_Enable.String(), d.shim.Exec(d.name, event.CreateEventStatement())
   258  }
   259  
   260  // NeedsToReloadEvents implements sql.EventDatabase
   261  func (d Database) NeedsToReloadEvents(_ *sql.Context, _ interface{}) (bool, error) {
   262  	// mysqlshim does not support event reloading
   263  	return false, nil
   264  }
   265  
   266  // UpdateLastExecuted implements sql.EventDatabase
   267  func (d Database) UpdateLastExecuted(ctx *sql.Context, eventName string, lastExecuted time.Time) error {
   268  	return nil
   269  }
   270  
   271  // CreateView implements the interface sql.ViewDatabase.
   272  func (d Database) CreateView(ctx *sql.Context, name string, selectStatement, createViewStmt string) error {
   273  	return d.shim.Exec(d.name, createViewStmt)
   274  }
   275  
   276  // DropView implements the interface sql.ViewDatabase.
   277  func (d Database) DropView(ctx *sql.Context, name string) error {
   278  	return d.shim.Exec(d.name, fmt.Sprintf("DROP VIEW `%s`;", name))
   279  }
   280  
   281  // GetViewDefinition implements the interface sql.ViewDatabase.
   282  func (d Database) GetViewDefinition(ctx *sql.Context, viewName string) (sql.ViewDefinition, bool, error) {
   283  	views, err := d.AllViews(ctx)
   284  	if err != nil {
   285  		return sql.ViewDefinition{}, false, err
   286  	}
   287  	lowerName := strings.ToLower(viewName)
   288  	for _, view := range views {
   289  		if lowerName == strings.ToLower(view.Name) {
   290  			return view, true, nil
   291  		}
   292  	}
   293  	return sql.ViewDefinition{}, false, nil
   294  }
   295  
   296  // AllViews implements the interface sql.ViewDatabase.
   297  func (d Database) AllViews(ctx *sql.Context) ([]sql.ViewDefinition, error) {
   298  	views, err := d.shim.QueryRows("", fmt.Sprintf("SELECT * FROM information_schema.TABLES WHERE TABLE_SCHEMA = '%s' AND TABLE_TYPE = 'VIEW';", d.name))
   299  	if err != nil {
   300  		return nil, err
   301  	}
   302  	viewDefinitions := make([]sql.ViewDefinition, len(views))
   303  	for i, view := range views {
   304  		viewName := view[2].(string)
   305  		viewStatementRow, err := d.shim.QueryRows("", fmt.Sprintf("SHOW CREATE VIEW `%s`.`%s`;", d.name, viewName))
   306  		if err != nil {
   307  			return nil, err
   308  		}
   309  		createViewStatement := viewStatementRow[0][1].(string)
   310  		viewStatement := createViewStatement[strings.Index(createViewStatement, " AS ")+4:] // not the best but works for now
   311  		viewDefinitions[i] = sql.ViewDefinition{
   312  			Name:                viewName,
   313  			TextDefinition:      viewStatement,
   314  			CreateViewStatement: createViewStatement,
   315  		}
   316  	}
   317  	return viewDefinitions, nil
   318  }