github.com/hasnat/dolt/go@v0.0.0-20210628190320-9eb5d843fbb7/libraries/doltcore/sqle/enginetest/dolt_harness.go (about)

     1  // Copyright 2020 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package enginetest
    16  
    17  import (
    18  	"context"
    19  	"runtime"
    20  	"strings"
    21  	"testing"
    22  
    23  	"github.com/dolthub/go-mysql-server/enginetest"
    24  	"github.com/dolthub/go-mysql-server/sql"
    25  	"github.com/stretchr/testify/require"
    26  
    27  	"github.com/dolthub/dolt/go/libraries/doltcore/dtestutils"
    28  	"github.com/dolthub/dolt/go/libraries/doltcore/sqle"
    29  	"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dfunctions"
    30  )
    31  
// DoltHarness adapts Dolt databases and sessions to the go-mysql-server enginetest
// harness interfaces.
//
// Fields:
//   - t: the test that owns the harness; setup failures are reported through it.
//   - session: the Dolt session returned by NewContext. NewDatabases replaces it
//     with a freshly created session on every call.
//   - transactionsEnabled: the value written to the transactions-enabled session
//     variable of every session the harness creates.
//   - databases: the databases created by the most recent NewDatabases call;
//     NewSession registers each of them with the session it creates.
//   - parallelism: configured engine parallelism; values <= 0 mean Parallelism
//     picks a default.
//   - skippedQueries: substrings matched case-insensitively by SkipQueryTest.
type DoltHarness struct {
	t                   *testing.T
	session             *sqle.DoltSession
	transactionsEnabled bool
	databases           []sqle.Database
	parallelism         int
	skippedQueries      []string
}
    40  
    41  var _ enginetest.Harness = (*DoltHarness)(nil)
    42  var _ enginetest.SkippingHarness = (*DoltHarness)(nil)
    43  var _ enginetest.IndexHarness = (*DoltHarness)(nil)
    44  var _ enginetest.VersionedDBHarness = (*DoltHarness)(nil)
    45  var _ enginetest.ForeignKeyHarness = (*DoltHarness)(nil)
    46  var _ enginetest.KeylessTableHarness = (*DoltHarness)(nil)
    47  
    48  func newDoltHarness(t *testing.T) *DoltHarness {
    49  	session, err := sqle.NewDoltSession(sql.NewEmptyContext(), enginetest.NewBaseSession(), "test", "email@test.com")
    50  	require.NoError(t, err)
    51  	return &DoltHarness{
    52  		t:              t,
    53  		session:        session,
    54  		skippedQueries: defaultSkippedQueries,
    55  	}
    56  }
    57  
    58  // withTransactionsEnabled returns a copy of this harness with transactions enabled or not for all sessions
    59  func (d DoltHarness) withTransactionsEnabled(enabled bool) *DoltHarness {
    60  	d.transactionsEnabled = enabled
    61  	d.setTransactionSessionVar(d.session, enabled)
    62  	return &d
    63  }
    64  
    65  func (d DoltHarness) setTransactionSessionVar(session *sqle.DoltSession, enabled bool) {
    66  	err := session.SetSessionVariable(sql.NewEmptyContext(), sqle.TransactionsEnabledSysVar, enabled)
    67  	require.NoError(d.t, err)
    68  }
    69  
    70  var defaultSkippedQueries = []string{
    71  	"show variables",           // we set extra variables
    72  	"show create table fk_tbl", // we create an extra key for the FK that vanilla gms does not
    73  	"show indexes from",        // we create / expose extra indexes (for foreign keys)
    74  	"json_arrayagg",            // TODO: aggregation ordering
    75  	"json_objectagg",           // TODO: aggregation ordering
    76  	"typestable",               // Bit type isn't working?
    77  	"dolt_commit_diff_",        // see broken queries in `dolt_system_table_queries.go`
    78  }
    79  
    80  // WithParallelism returns a copy of the harness with parallelism set to the given number of threads. A value of 0 or
    81  // less means to use the system parallelism settings.
    82  func (d DoltHarness) WithParallelism(parallelism int) *DoltHarness {
    83  	d.parallelism = parallelism
    84  	return &d
    85  }
    86  
    87  // WithSkippedQueries returns a copy of the harness with the given queries skipped
    88  func (d DoltHarness) WithSkippedQueries(queries []string) *DoltHarness {
    89  	d.skippedQueries = queries
    90  	return &d
    91  }
    92  
    93  // SkipQueryTest returns whether to skip a query
    94  func (d *DoltHarness) SkipQueryTest(query string) bool {
    95  	lowerQuery := strings.ToLower(query)
    96  	for _, skipped := range d.skippedQueries {
    97  		if strings.Contains(lowerQuery, strings.ToLower(skipped)) {
    98  			return true
    99  		}
   100  	}
   101  
   102  	return false
   103  }
   104  
   105  func (d *DoltHarness) Parallelism() int {
   106  	if d.parallelism <= 0 {
   107  
   108  		// always test with some parallelism
   109  		parallelism := runtime.NumCPU()
   110  
   111  		if parallelism <= 1 {
   112  			parallelism = 2
   113  		}
   114  
   115  		return parallelism
   116  	}
   117  
   118  	return d.parallelism
   119  }
   120  
   121  func (d *DoltHarness) NewContext() *sql.Context {
   122  	return sql.NewContext(
   123  		context.Background(),
   124  		sql.WithSession(d.session))
   125  }
   126  
   127  func (d DoltHarness) NewSession() *sql.Context {
   128  	session, err := sqle.NewDoltSession(sql.NewEmptyContext(), enginetest.NewBaseSession(), "test", "email@test.com")
   129  	require.NoError(d.t, err)
   130  
   131  	d.setTransactionSessionVar(session, d.transactionsEnabled)
   132  
   133  	ctx := sql.NewContext(
   134  		context.Background(),
   135  		sql.WithSession(session))
   136  
   137  	for _, db := range d.databases {
   138  		err := session.AddDB(ctx, db, db.DbData())
   139  		require.NoError(d.t, err)
   140  	}
   141  
   142  	return ctx
   143  }
   144  
   145  func (d *DoltHarness) SupportsNativeIndexCreation() bool {
   146  	return true
   147  }
   148  
   149  func (d *DoltHarness) SupportsForeignKeys() bool {
   150  	return true
   151  }
   152  
   153  func (d *DoltHarness) SupportsKeylessTables() bool {
   154  	return true
   155  }
   156  
   157  func (d *DoltHarness) NewDatabase(name string) sql.Database {
   158  	return d.NewDatabases(name)[0]
   159  }
   160  
   161  func (d *DoltHarness) NewDatabases(names ...string) []sql.Database {
   162  	dEnv := dtestutils.CreateTestEnv()
   163  
   164  	// TODO: it should be safe to reuse a session with a new database, but it isn't in all cases. Particularly, if you
   165  	//  have a database that only ever receives read queries, and then you re-use its session for a new database with
   166  	//  the same name, the first write query will panic on dangling references in the noms layer. Not sure why this is
   167  	//  happening, but it only happens as a result of this test setup.
   168  	var err error
   169  	d.session, err = sqle.NewDoltSession(sql.NewEmptyContext(), enginetest.NewBaseSession(), "test", "email@test.com")
   170  	require.NoError(d.t, err)
   171  
   172  	d.setTransactionSessionVar(d.session, d.transactionsEnabled)
   173  
   174  	var dbs []sql.Database
   175  	d.databases = nil
   176  	for _, name := range names {
   177  		db := sqle.NewDatabase(name, dEnv.DbData())
   178  		require.NoError(d.t, d.session.AddDB(enginetest.NewContext(d), db, db.DbData()))
   179  		dbs = append(dbs, db)
   180  		d.databases = append(d.databases, db)
   181  	}
   182  
   183  	return dbs
   184  }
   185  
   186  func (d *DoltHarness) NewTable(db sql.Database, name string, schema sql.Schema) (sql.Table, error) {
   187  	doltDatabase := db.(sqle.Database)
   188  	err := doltDatabase.CreateTable(enginetest.NewContext(d).WithCurrentDB(db.Name()), name, schema)
   189  	if err != nil {
   190  		return nil, err
   191  	}
   192  
   193  	table, ok, err := doltDatabase.GetTableInsensitive(enginetest.NewContext(d).WithCurrentDB(db.Name()), name)
   194  
   195  	require.NoError(d.t, err)
   196  	require.True(d.t, ok, "table %s not found after creation", name)
   197  	return table, nil
   198  }
   199  
   200  // Dolt doesn't version tables per se, just the entire database. So ignore the name and schema and just create a new
   201  // branch with the given name.
   202  func (d *DoltHarness) NewTableAsOf(db sql.VersionedDatabase, name string, schema sql.Schema, asOf interface{}) sql.Table {
   203  	table, err := d.NewTable(db, name, schema)
   204  	if err != nil {
   205  		require.True(d.t, sql.ErrTableAlreadyExists.Is(err))
   206  	}
   207  
   208  	table, ok, err := db.GetTableInsensitive(enginetest.NewContext(d), name)
   209  	require.NoError(d.t, err)
   210  	require.True(d.t, ok)
   211  
   212  	return table
   213  }
   214  
   215  // Dolt doesn't version tables per se, just the entire database. So ignore the name and schema and just create a new
   216  // branch with the given name.
   217  func (d *DoltHarness) SnapshotTable(db sql.VersionedDatabase, name string, asOf interface{}) error {
   218  	ddb := db.(sqle.Database)
   219  	e := enginetest.NewEngineWithDbs(d.t, d, []sql.Database{db}, nil)
   220  
   221  	if _, err := e.Catalog.FunctionRegistry.Function(dfunctions.CommitFuncName); sql.ErrFunctionNotFound.Is(err) {
   222  		require.NoError(d.t,
   223  			e.Catalog.FunctionRegistry.Register(sql.FunctionN{Name: dfunctions.CommitFuncName, Fn: dfunctions.NewCommitFunc}))
   224  	}
   225  
   226  	asOfString, ok := asOf.(string)
   227  	require.True(d.t, ok)
   228  
   229  	ctx := enginetest.NewContext(d)
   230  	_, iter, err := e.Query(ctx,
   231  		"set @@"+sqle.HeadKey(ddb.Name())+" = COMMIT('-m', 'test commit');")
   232  	require.NoError(d.t, err)
   233  	_, err = sql.RowIterToRows(ctx, iter)
   234  	require.NoError(d.t, err)
   235  
   236  	headHash, err := ctx.GetSessionVariable(ctx, sqle.HeadKey(ddb.Name()))
   237  	require.NoError(d.t, err)
   238  
   239  	ctx = enginetest.NewContext(d)
   240  	// TODO: there's a bug in test setup with transactions, where the HEAD session var gets overwritten on transaction
   241  	//  start, so we quote it here instead
   242  	// query := "insert into dolt_branches (name, hash) values ('" + asOfString + "', @@" + sqle.HeadKey(ddb.Name()) + ")"
   243  	query := "insert into dolt_branches (name, hash) values ('" + asOfString + "', '" + headHash.(string) + "')"
   244  
   245  	_, iter, err = e.Query(ctx,
   246  		query)
   247  	require.NoError(d.t, err)
   248  	_, err = sql.RowIterToRows(ctx, iter)
   249  	require.NoError(d.t, err)
   250  
   251  	return nil
   252  }