github.com/dolthub/go-mysql-server@v0.18.0/enginetest/testdata.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package enginetest
    16  
    17  import (
    18  	"testing"
    19  
    20  	"github.com/stretchr/testify/require"
    21  
    22  	"github.com/dolthub/go-mysql-server/sql"
    23  	"github.com/dolthub/go-mysql-server/sql/mysql_db"
    24  	"github.com/dolthub/go-mysql-server/sql/types"
    25  )
    26  
    27  // wrapInTransaction runs the function given surrounded in a transaction. If the db provided doesn't implement
    28  // sql.TransactionDatabase, then the function is simply run and the transaction logic is a no-op.
    29  func wrapInTransaction(t *testing.T, db sql.Database, harness Harness, fn func()) {
    30  	ctx := NewContext(harness)
    31  	ctx.SetCurrentDatabase(db.Name())
    32  	if privilegedDatabase, ok := db.(mysql_db.PrivilegedDatabase); ok {
    33  		db = privilegedDatabase.Unwrap()
    34  	}
    35  
    36  	ts, transactionsSupported := ctx.Session.(sql.TransactionSession)
    37  
    38  	if transactionsSupported {
    39  		tx, err := ts.StartTransaction(ctx, sql.ReadWrite)
    40  		require.NoError(t, err)
    41  		ctx.SetTransaction(tx)
    42  	}
    43  
    44  	fn()
    45  
    46  	if transactionsSupported {
    47  		tx := ctx.GetTransaction()
    48  		if tx != nil {
    49  			err := ts.CommitTransaction(ctx, tx)
    50  			require.NoError(t, err)
    51  			ctx.SetTransaction(nil)
    52  		}
    53  	}
    54  }
    55  
    56  func createVersionedTables(t *testing.T, harness Harness, myDb, foo sql.Database) []sql.Database {
    57  	var table sql.Table
    58  	var err error
    59  
    60  	if versionedHarness, ok := harness.(VersionedDBHarness); ok {
    61  		versionedDb, ok := myDb.(sql.VersionedDatabase)
    62  		if !ok {
    63  			require.Failf(t, "expected a sql.VersionedDatabase", "%T is not a sql.VersionedDatabase", myDb)
    64  		}
    65  
    66  		wrapInTransaction(t, myDb, harness, func() {
    67  			table = versionedHarness.NewTableAsOf(versionedDb, "myhistorytable", sql.NewPrimaryKeySchema(sql.Schema{
    68  				{Name: "i", Type: types.Int64, Source: "myhistorytable", PrimaryKey: true},
    69  				{Name: "s", Type: types.Text, Source: "myhistorytable"},
    70  			}), "2019-01-01")
    71  
    72  			if err == nil {
    73  				InsertRows(t, NewContext(harness), mustInsertableTable(t, table),
    74  					sql.NewRow(int64(1), "first row, 1"),
    75  					sql.NewRow(int64(2), "second row, 1"),
    76  					sql.NewRow(int64(3), "third row, 1"))
    77  			} else {
    78  				t.Logf("Warning: could not create table %s: %s", "myhistorytable", err)
    79  			}
    80  		})
    81  
    82  		wrapInTransaction(t, myDb, harness, func() {
    83  			require.NoError(t, versionedHarness.SnapshotTable(versionedDb, "myhistorytable", "2019-01-01"))
    84  		})
    85  
    86  		wrapInTransaction(t, myDb, harness, func() {
    87  			table = versionedHarness.NewTableAsOf(versionedDb, "myhistorytable", sql.NewPrimaryKeySchema(sql.Schema{
    88  				{Name: "i", Type: types.Int64, Source: "myhistorytable", PrimaryKey: true},
    89  				{Name: "s", Type: types.Text, Source: "myhistorytable"},
    90  			}), "2019-01-02")
    91  
    92  			if err == nil {
    93  				DeleteRows(t, NewContext(harness), mustDeletableTable(t, table),
    94  					sql.NewRow(int64(1), "first row, 1"),
    95  					sql.NewRow(int64(2), "second row, 1"),
    96  					sql.NewRow(int64(3), "third row, 1"))
    97  				InsertRows(t, NewContext(harness), mustInsertableTable(t, table),
    98  					sql.NewRow(int64(1), "first row, 2"),
    99  					sql.NewRow(int64(2), "second row, 2"),
   100  					sql.NewRow(int64(3), "third row, 2"))
   101  			} else {
   102  				t.Logf("Warning: could not create table %s: %s", "myhistorytable", err)
   103  			}
   104  		})
   105  
   106  		wrapInTransaction(t, myDb, harness, func() {
   107  			require.NoError(t, versionedHarness.SnapshotTable(versionedDb, "myhistorytable", "2019-01-02"))
   108  		})
   109  
   110  		wrapInTransaction(t, myDb, harness, func() {
   111  			table = versionedHarness.NewTableAsOf(versionedDb, "myhistorytable", sql.NewPrimaryKeySchema(sql.Schema{
   112  				{Name: "i", Type: types.Int64, Source: "myhistorytable", PrimaryKey: true},
   113  				{Name: "s", Type: types.Text, Source: "myhistorytable"},
   114  			}), "2019-01-03")
   115  
   116  			if err == nil {
   117  				DeleteRows(t, NewContext(harness), mustDeletableTable(t, table),
   118  					sql.NewRow(int64(1), "first row, 2"),
   119  					sql.NewRow(int64(2), "second row, 2"),
   120  					sql.NewRow(int64(3), "third row, 2"))
   121  				column := sql.Column{Name: "c", Type: types.Text}
   122  				AddColumn(t, NewContext(harness), mustAlterableTable(t, table), &column)
   123  				InsertRows(t, NewContext(harness), mustInsertableTable(t, table),
   124  					sql.NewRow(int64(1), "first row, 3", "1"),
   125  					sql.NewRow(int64(2), "second row, 3", "2"),
   126  					sql.NewRow(int64(3), "third row, 3", "3"))
   127  			}
   128  		})
   129  
   130  		wrapInTransaction(t, myDb, harness, func() {
   131  			table = versionedHarness.NewTableAsOf(versionedDb, "mytable", sql.NewPrimaryKeySchema(sql.Schema{
   132  				{Name: "i", Type: types.Int64, Source: "mytable", PrimaryKey: true},
   133  				{Name: "s", Type: types.Text, Source: "mytable"},
   134  			}), nil)
   135  
   136  			if err == nil {
   137  				InsertRows(t, NewContext(harness), mustInsertableTable(t, table),
   138  					sql.NewRow(int64(1), "first row, 1"),
   139  					sql.NewRow(int64(2), "second row, 2"),
   140  					sql.NewRow(int64(3), "third row, 3"))
   141  			}
   142  		})
   143  
   144  		wrapInTransaction(t, myDb, harness, func() {
   145  			require.NoError(t, versionedHarness.SnapshotTable(versionedDb, "myhistorytable", "2019-01-03"))
   146  		})
   147  	}
   148  
   149  	return []sql.Database{myDb, foo}
   150  }
   151  
   152  // CreateVersionedTestData uses the provided harness to create test tables and data for many of the other tests.
   153  func CreateVersionedTestData(t *testing.T, harness VersionedDBHarness) []sql.Database {
   154  	// TODO: only create mydb here
   155  	dbs := harness.NewDatabases("mydb", "foo")
   156  	return createVersionedTables(t, harness, dbs[0], dbs[1])
   157  }
   158  
   159  func mustInsertableTable(t *testing.T, table sql.Table) sql.InsertableTable {
   160  	insertable, ok := table.(sql.InsertableTable)
   161  	require.True(t, ok, "Table must implement sql.InsertableTable")
   162  	return insertable
   163  }
   164  
   165  func mustDeletableTable(t *testing.T, table sql.Table) sql.DeletableTable {
   166  	deletable, ok := table.(sql.DeletableTable)
   167  	require.True(t, ok, "Table must implement sql.DeletableTable")
   168  	return deletable
   169  }
   170  
   171  func mustAlterableTable(t *testing.T, table sql.Table) sql.AlterableTable {
   172  	alterable, ok := table.(sql.AlterableTable)
   173  	require.True(t, ok, "Table must implement sql.AlterableTable")
   174  	return alterable
   175  }
   176  
   177  func InsertRows(t *testing.T, ctx *sql.Context, table sql.InsertableTable, rows ...sql.Row) {
   178  	t.Helper()
   179  
   180  	inserter := table.Inserter(ctx)
   181  	for _, r := range rows {
   182  		require.NoError(t, inserter.Insert(ctx, r))
   183  	}
   184  	err := inserter.Close(ctx)
   185  	require.NoError(t, err)
   186  }
   187  
   188  // AddColumn adds a column to the specified table
   189  func AddColumn(t *testing.T, ctx *sql.Context, table sql.AlterableTable, column *sql.Column) {
   190  	t.Helper()
   191  
   192  	err := table.AddColumn(ctx, column, nil)
   193  	require.NoError(t, err)
   194  }
   195  
   196  func DeleteRows(t *testing.T, ctx *sql.Context, table sql.DeletableTable, rows ...sql.Row) {
   197  	t.Helper()
   198  
   199  	deleter := table.Deleter(ctx)
   200  	for _, r := range rows {
   201  		if err := deleter.Delete(ctx, r); err != nil {
   202  			require.True(t, sql.ErrDeleteRowNotFound.Is(err))
   203  		}
   204  	}
   205  	require.NoError(t, deleter.Close(ctx))
   206  }