github.com/dolthub/go-mysql-server@v0.18.0/enginetest/testdata.go (about) 1 // Copyright 2020-2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package enginetest 16 17 import ( 18 "testing" 19 20 "github.com/stretchr/testify/require" 21 22 "github.com/dolthub/go-mysql-server/sql" 23 "github.com/dolthub/go-mysql-server/sql/mysql_db" 24 "github.com/dolthub/go-mysql-server/sql/types" 25 ) 26 27 // wrapInTransaction runs the function given surrounded in a transaction. If the db provided doesn't implement 28 // sql.TransactionDatabase, then the function is simply run and the transaction logic is a no-op. 29 func wrapInTransaction(t *testing.T, db sql.Database, harness Harness, fn func()) { 30 ctx := NewContext(harness) 31 ctx.SetCurrentDatabase(db.Name()) 32 if privilegedDatabase, ok := db.(mysql_db.PrivilegedDatabase); ok { 33 db = privilegedDatabase.Unwrap() 34 } 35 36 ts, transactionsSupported := ctx.Session.(sql.TransactionSession) 37 38 if transactionsSupported { 39 tx, err := ts.StartTransaction(ctx, sql.ReadWrite) 40 require.NoError(t, err) 41 ctx.SetTransaction(tx) 42 } 43 44 fn() 45 46 if transactionsSupported { 47 tx := ctx.GetTransaction() 48 if tx != nil { 49 err := ts.CommitTransaction(ctx, tx) 50 require.NoError(t, err) 51 ctx.SetTransaction(nil) 52 } 53 } 54 } 55 56 func createVersionedTables(t *testing.T, harness Harness, myDb, foo sql.Database) []sql.Database { 57 var table sql.Table 58 var err error 59 60 if versionedHarness, ok := harness.(VersionedDBHarness); ok { 61 versionedDb, ok := myDb.(sql.VersionedDatabase) 62 if !ok { 63 require.Failf(t, "expected a sql.VersionedDatabase", "%T is not a sql.VersionedDatabase", myDb) 64 } 65 66 wrapInTransaction(t, myDb, harness, func() { 67 table = versionedHarness.NewTableAsOf(versionedDb, "myhistorytable", sql.NewPrimaryKeySchema(sql.Schema{ 68 {Name: "i", Type: types.Int64, Source: "myhistorytable", PrimaryKey: true}, 69 {Name: "s", Type: types.Text, Source: "myhistorytable"}, 70 }), "2019-01-01") 71 72 if err == nil { 73 InsertRows(t, NewContext(harness), mustInsertableTable(t, table), 74 sql.NewRow(int64(1), "first row, 1"), 75 sql.NewRow(int64(2), "second row, 1"), 76 sql.NewRow(int64(3), "third row, 1")) 77 } else { 78 t.Logf("Warning: could not create table %s: %s", "myhistorytable", err) 79 } 80 }) 81 82 wrapInTransaction(t, myDb, harness, func() { 83 require.NoError(t, versionedHarness.SnapshotTable(versionedDb, "myhistorytable", "2019-01-01")) 84 }) 85 86 wrapInTransaction(t, myDb, harness, func() { 87 table = versionedHarness.NewTableAsOf(versionedDb, "myhistorytable", sql.NewPrimaryKeySchema(sql.Schema{ 88 {Name: "i", Type: types.Int64, Source: "myhistorytable", PrimaryKey: true}, 89 {Name: "s", Type: types.Text, Source: "myhistorytable"}, 90 }), "2019-01-02") 91 92 if err == nil { 93 DeleteRows(t, NewContext(harness), mustDeletableTable(t, table), 94 sql.NewRow(int64(1), "first row, 1"), 95 sql.NewRow(int64(2), "second row, 1"), 96 sql.NewRow(int64(3), "third row, 1")) 97 InsertRows(t, NewContext(harness), mustInsertableTable(t, table), 98 sql.NewRow(int64(1), "first row, 2"), 99 sql.NewRow(int64(2), "second row, 2"), 100 sql.NewRow(int64(3), "third row, 2")) 101 } else { 102 t.Logf("Warning: could not create table %s: %s", "myhistorytable", err) 103 } 104 }) 105 106 wrapInTransaction(t, myDb, harness, func() { 107 require.NoError(t, versionedHarness.SnapshotTable(versionedDb, "myhistorytable", "2019-01-02")) 108 }) 109 110 wrapInTransaction(t, myDb, harness, func() { 111 table = versionedHarness.NewTableAsOf(versionedDb, "myhistorytable", sql.NewPrimaryKeySchema(sql.Schema{ 112 {Name: "i", Type: types.Int64, Source: "myhistorytable", PrimaryKey: true}, 113 {Name: "s", Type: types.Text, Source: "myhistorytable"}, 114 }), "2019-01-03") 115 116 if err == nil { 117 DeleteRows(t, NewContext(harness), mustDeletableTable(t, table), 118 sql.NewRow(int64(1), "first row, 2"), 119 sql.NewRow(int64(2), "second row, 2"), 120 sql.NewRow(int64(3), "third row, 2")) 121 column := sql.Column{Name: "c", Type: types.Text} 122 AddColumn(t, NewContext(harness), mustAlterableTable(t, table), &column) 123 InsertRows(t, NewContext(harness), mustInsertableTable(t, table), 124 sql.NewRow(int64(1), "first row, 3", "1"), 125 sql.NewRow(int64(2), "second row, 3", "2"), 126 sql.NewRow(int64(3), "third row, 3", "3")) 127 } 128 }) 129 130 wrapInTransaction(t, myDb, harness, func() { 131 table = versionedHarness.NewTableAsOf(versionedDb, "mytable", sql.NewPrimaryKeySchema(sql.Schema{ 132 {Name: "i", Type: types.Int64, Source: "mytable", PrimaryKey: true}, 133 {Name: "s", Type: types.Text, Source: "mytable"}, 134 }), nil) 135 136 if err == nil { 137 InsertRows(t, NewContext(harness), mustInsertableTable(t, table), 138 sql.NewRow(int64(1), "first row, 1"), 139 sql.NewRow(int64(2), "second row, 2"), 140 sql.NewRow(int64(3), "third row, 3")) 141 } 142 }) 143 144 wrapInTransaction(t, myDb, harness, func() { 145 require.NoError(t, versionedHarness.SnapshotTable(versionedDb, "myhistorytable", "2019-01-03")) 146 }) 147 } 148 149 return []sql.Database{myDb, foo} 150 } 151 152 // CreateVersionedTestData uses the provided harness to create test tables and data for many of the other tests. 153 func CreateVersionedTestData(t *testing.T, harness VersionedDBHarness) []sql.Database { 154 // TODO: only create mydb here 155 dbs := harness.NewDatabases("mydb", "foo") 156 return createVersionedTables(t, harness, dbs[0], dbs[1]) 157 } 158 159 func mustInsertableTable(t *testing.T, table sql.Table) sql.InsertableTable { 160 insertable, ok := table.(sql.InsertableTable) 161 require.True(t, ok, "Table must implement sql.InsertableTable") 162 return insertable 163 } 164 165 func mustDeletableTable(t *testing.T, table sql.Table) sql.DeletableTable { 166 deletable, ok := table.(sql.DeletableTable) 167 require.True(t, ok, "Table must implement sql.DeletableTable") 168 return deletable 169 } 170 171 func mustAlterableTable(t *testing.T, table sql.Table) sql.AlterableTable { 172 alterable, ok := table.(sql.AlterableTable) 173 require.True(t, ok, "Table must implement sql.AlterableTable") 174 return alterable 175 } 176 177 func InsertRows(t *testing.T, ctx *sql.Context, table sql.InsertableTable, rows ...sql.Row) { 178 t.Helper() 179 180 inserter := table.Inserter(ctx) 181 for _, r := range rows { 182 require.NoError(t, inserter.Insert(ctx, r)) 183 } 184 err := inserter.Close(ctx) 185 require.NoError(t, err) 186 } 187 188 // AddColumn adds a column to the specified table 189 func AddColumn(t *testing.T, ctx *sql.Context, table sql.AlterableTable, column *sql.Column) { 190 t.Helper() 191 192 err := table.AddColumn(ctx, column, nil) 193 require.NoError(t, err) 194 } 195 196 func DeleteRows(t *testing.T, ctx *sql.Context, table sql.DeletableTable, rows ...sql.Row) { 197 t.Helper() 198 199 deleter := table.Deleter(ctx) 200 for _, r := range rows { 201 if err := deleter.Delete(ctx, r); err != nil { 202 require.True(t, sql.ErrDeleteRowNotFound.Is(err)) 203 } 204 } 205 require.NoError(t, deleter.Close(ctx)) 206 }