github.com/dolthub/go-mysql-server@v0.18.0/enginetest/memory_harness.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package enginetest
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"os"
    21  	"strings"
    22  	"sync"
    23  	"testing"
    24  
    25  	"github.com/dolthub/vitess/go/mysql"
    26  
    27  	sqle "github.com/dolthub/go-mysql-server"
    28  	"github.com/dolthub/go-mysql-server/enginetest/scriptgen/setup"
    29  	"github.com/dolthub/go-mysql-server/memory"
    30  	"github.com/dolthub/go-mysql-server/server"
    31  	"github.com/dolthub/go-mysql-server/sql"
    32  	"github.com/dolthub/go-mysql-server/sql/mysql_db"
    33  )
    34  
    35  const testNumPartitions = 5
    36  
    37  type IndexDriverInitializer func([]sql.Database) sql.IndexDriver
    38  
// MemoryHarness is the enginetest Harness implementation backed by the
// in-memory database provider from the memory package.
type MemoryHarness struct {
	// name is the label passed to NewMemoryHarness.
	name                      string
	// parallelism is the value reported by Parallelism.
	parallelism               int
	// numTablePartitions is the partition count used for tables created by
	// NewTableAsOf.
	numTablePartitions        int
	// readonly is set by NewReadOnlyMemoryHarness. NOTE(review): it is not read
	// in this file; presumably consulted elsewhere in the package — confirm.
	readonly                  bool
	// provider is created lazily by getProvider (under mu) and cleared by
	// NewEngine unless retainSession is set.
	provider                  sql.DatabaseProvider
	// indexDriverInitializer, if non-nil, builds the index driver used by
	// InitializeIndexDriver and IndexDriver.
	indexDriverInitializer    IndexDriverInitializer
	// driver is the index driver built by InitializeIndexDriver; newSession
	// registers it with each new session when non-nil.
	driver                    sql.IndexDriver
	// nativeIndexSupport enables native index support in the provider and
	// primary-key indexes on tables created by NewTableAsOf.
	nativeIndexSupport        bool
	// skippedQueries holds lower-cased queries that SkipQueryTest reports as
	// skipped.
	skippedQueries            map[string]struct{}
	// session is the session bound to contexts returned by NewContext; it is
	// created on demand and cleared by NewEngine unless retainSession is set.
	session                   sql.Session
	// retainSession, once set by NewTableAsOf, stops NewEngine from discarding
	// session and provider.
	retainSession             bool
	// setupData holds the scripts recorded by Setup and passed to NewEngine.
	setupData                 []setup.SetupScript
	// externalProcedureRegistry backs the sql.ExternalStoredProcedureProvider
	// methods.
	externalProcedureRegistry sql.ExternalStoredProcedureRegistry
	// server makes NewEngine wrap engines with NewServerQueryEngine.
	server                    bool
	// mu serializes access to provider in getProvider. It is a pointer, so
	// struct copies (WithProvider, NewSkippingMemoryHarness) share the lock.
	mu                        *sync.Mutex
}
    56  
    57  var _ Harness = (*MemoryHarness)(nil)
    58  var _ IndexDriverHarness = (*MemoryHarness)(nil)
    59  var _ IndexHarness = (*MemoryHarness)(nil)
    60  var _ VersionedDBHarness = (*MemoryHarness)(nil)
    61  var _ ReadOnlyDatabaseHarness = (*MemoryHarness)(nil)
    62  var _ ForeignKeyHarness = (*MemoryHarness)(nil)
    63  var _ KeylessTableHarness = (*MemoryHarness)(nil)
    64  var _ ClientHarness = (*MemoryHarness)(nil)
    65  var _ ServerHarness = (*MemoryHarness)(nil)
    66  var _ sql.ExternalStoredProcedureProvider = (*MemoryHarness)(nil)
    67  
    68  func NewMemoryHarness(name string, parallelism int, numTablePartitions int, useNativeIndexes bool, driverInitializer IndexDriverInitializer) *MemoryHarness {
    69  	externalProcedureRegistry := sql.NewExternalStoredProcedureRegistry()
    70  	for _, esp := range memory.ExternalStoredProcedures {
    71  		externalProcedureRegistry.Register(esp)
    72  	}
    73  
    74  	var useServer bool
    75  	if _, ok := os.LookupEnv("SERVER_ENGINE_TEST"); ok {
    76  		useServer = true
    77  	}
    78  
    79  	return &MemoryHarness{
    80  		name:                      name,
    81  		numTablePartitions:        numTablePartitions,
    82  		indexDriverInitializer:    driverInitializer,
    83  		parallelism:               parallelism,
    84  		nativeIndexSupport:        useNativeIndexes,
    85  		skippedQueries:            make(map[string]struct{}),
    86  		externalProcedureRegistry: externalProcedureRegistry,
    87  		mu:                        &sync.Mutex{},
    88  		server:                    useServer,
    89  	}
    90  }
    91  
    92  func NewDefaultMemoryHarness() *MemoryHarness {
    93  	return NewMemoryHarness("default", 1, testNumPartitions, true, nil)
    94  }
    95  
    96  func NewReadOnlyMemoryHarness() *MemoryHarness {
    97  	h := NewDefaultMemoryHarness()
    98  	h.readonly = true
    99  	return h
   100  }
   101  
   102  func (m *MemoryHarness) SessionBuilder() server.SessionBuilder {
   103  	return func(ctx context.Context, c *mysql.Conn, addr string) (sql.Session, error) {
   104  		host := ""
   105  		user := ""
   106  		mysqlConnectionUser, ok := c.UserData.(mysql_db.MysqlConnectionUser)
   107  		if ok {
   108  			host = mysqlConnectionUser.Host
   109  			user = mysqlConnectionUser.User
   110  		}
   111  		client := sql.Client{Address: host, User: user, Capabilities: c.Capabilities}
   112  		baseSession := sql.NewBaseSessionWithClientServer(addr, client, c.ConnectionID)
   113  		return memory.NewSession(baseSession, m.getProvider()), nil
   114  	}
   115  }
   116  
   117  // ExternalStoredProcedure implements the sql.ExternalStoredProcedureProvider interface
   118  func (m *MemoryHarness) ExternalStoredProcedure(_ *sql.Context, name string, numOfParams int) (*sql.ExternalStoredProcedureDetails, error) {
   119  	return m.externalProcedureRegistry.LookupByNameAndParamCount(name, numOfParams)
   120  }
   121  
   122  // ExternalStoredProcedures implements the sql.ExternalStoredProcedureProvider interface
   123  func (m *MemoryHarness) ExternalStoredProcedures(_ *sql.Context, name string) ([]sql.ExternalStoredProcedureDetails, error) {
   124  	return m.externalProcedureRegistry.LookupByName(name)
   125  }
   126  
   127  func (m *MemoryHarness) InitializeIndexDriver(dbs []sql.Database) {
   128  	if m.indexDriverInitializer != nil {
   129  		m.driver = m.indexDriverInitializer(dbs)
   130  	}
   131  }
   132  
   133  func (m *MemoryHarness) NewSession() *sql.Context {
   134  	m.session = m.newSession()
   135  	return m.NewContext()
   136  }
   137  
   138  func (m *MemoryHarness) SkipQueryTest(query string) bool {
   139  	_, ok := m.skippedQueries[strings.ToLower(query)]
   140  	return ok
   141  }
   142  
   143  func (m *MemoryHarness) QueriesToSkip(queries ...string) {
   144  	for _, query := range queries {
   145  		m.skippedQueries[strings.ToLower(query)] = struct{}{}
   146  	}
   147  }
   148  
   149  func (m *MemoryHarness) UseServer() {
   150  	m.server = true
   151  }
   152  
   153  func (m *MemoryHarness) IsUsingServer() bool {
   154  	return m.server
   155  }
   156  
// SkippingMemoryHarness is a MemoryHarness whose SkipQueryTest reports every
// query as skipped.
type SkippingMemoryHarness struct {
	MemoryHarness
}

// Compile-time check that *SkippingMemoryHarness satisfies SkippingHarness.
var _ SkippingHarness = (*SkippingMemoryHarness)(nil)
   162  
   163  func NewSkippingMemoryHarness() *SkippingMemoryHarness {
   164  	return &SkippingMemoryHarness{
   165  		MemoryHarness: *NewDefaultMemoryHarness(),
   166  	}
   167  }
   168  
   169  func (s SkippingMemoryHarness) SkipQueryTest(query string) bool {
   170  	return true
   171  }
   172  
   173  func (m *MemoryHarness) Setup(setupData ...[]setup.SetupScript) {
   174  	m.setupData = nil
   175  	for i := range setupData {
   176  		m.setupData = append(m.setupData, setupData[i]...)
   177  	}
   178  	return
   179  }
   180  
   181  func (m *MemoryHarness) NewEngine(t *testing.T) (QueryEngine, error) {
   182  	if !m.retainSession {
   183  		m.session = nil
   184  		m.provider = nil
   185  	}
   186  	engine, err := NewEngine(t, m, m.getProvider(), m.setupData, memory.NewStatsProv())
   187  	if err != nil {
   188  		return nil, err
   189  	}
   190  
   191  	if m.server {
   192  		return NewServerQueryEngine(t, engine, m.SessionBuilder())
   193  	}
   194  
   195  	return engine, nil
   196  }
   197  
   198  func (m *MemoryHarness) NewTableAsOf(db sql.VersionedDatabase, name string, schema sql.PrimaryKeySchema, asOf interface{}) sql.Table {
   199  	var fkColl *memory.ForeignKeyCollection
   200  	var baseDb *memory.BaseDatabase
   201  	if memDb, ok := db.(*memory.HistoryDatabase); ok {
   202  		fkColl = memDb.GetForeignKeyCollection()
   203  		baseDb = memDb.BaseDatabase
   204  	} else if memDb, ok := db.(*memory.ReadOnlyDatabase); ok {
   205  		fkColl = memDb.GetForeignKeyCollection()
   206  		baseDb = memDb.BaseDatabase
   207  	} else {
   208  		panic(fmt.Sprintf("unexpected database type %T", db))
   209  	}
   210  	table := memory.NewPartitionedTableRevision(baseDb, name, schema, fkColl, m.numTablePartitions)
   211  	if m.nativeIndexSupport {
   212  		table.EnablePrimaryKeyIndexes()
   213  	}
   214  	if ro, ok := db.(memory.ReadOnlyDatabase); ok {
   215  		ro.HistoryDatabase.AddTableAsOf(name, table, asOf)
   216  	} else {
   217  		db.(*memory.HistoryDatabase).AddTableAsOf(name, table, asOf)
   218  	}
   219  
   220  	m.retainSession = true
   221  
   222  	return table
   223  }
   224  
   225  func (m *MemoryHarness) SnapshotTable(db sql.VersionedDatabase, name string, asOf interface{}) error {
   226  	// Nothing to do for this implementation: the NewTableAsOf method does all the work of creating the snapshot.
   227  	return nil
   228  }
   229  
   230  func (m *MemoryHarness) SupportsNativeIndexCreation() bool {
   231  	return m.nativeIndexSupport
   232  }
   233  
   234  func (m *MemoryHarness) SupportsForeignKeys() bool {
   235  	return true
   236  }
   237  
   238  func (m *MemoryHarness) SupportsKeylessTables() bool {
   239  	return true
   240  }
   241  
   242  func (m *MemoryHarness) Parallelism() int {
   243  	return m.parallelism
   244  }
   245  
   246  func (m *MemoryHarness) NewContext() *sql.Context {
   247  	if m.session == nil {
   248  		m.session = m.newSession()
   249  	}
   250  
   251  	return sql.NewContext(
   252  		context.Background(),
   253  		sql.WithSession(m.session),
   254  	)
   255  }
   256  
   257  func (m *MemoryHarness) newSession() *memory.Session {
   258  	baseSession := NewBaseSession()
   259  	session := memory.NewSession(baseSession, m.getProvider())
   260  	if m.driver != nil {
   261  		session.GetIndexRegistry().RegisterIndexDriver(m.driver)
   262  	}
   263  	return session
   264  }
   265  
   266  func (m *MemoryHarness) NewContextWithClient(client sql.Client) *sql.Context {
   267  	baseSession := sql.NewBaseSessionWithClientServer("address", client, 1)
   268  
   269  	return sql.NewContext(
   270  		context.Background(),
   271  		sql.WithSession(memory.NewSession(baseSession, m.getProvider())),
   272  	)
   273  }
   274  
   275  func (m *MemoryHarness) IndexDriver(dbs []sql.Database) sql.IndexDriver {
   276  	if m.indexDriverInitializer != nil {
   277  		return m.indexDriverInitializer(dbs)
   278  	}
   279  	return nil
   280  }
   281  
   282  func (m *MemoryHarness) newDatabase(name string) sql.Database {
   283  	ctx := m.NewContext()
   284  
   285  	err := m.getProvider().(*memory.DbProvider).CreateDatabase(ctx, name)
   286  	if err != nil {
   287  		panic(err)
   288  	}
   289  
   290  	db, _ := m.getProvider().Database(ctx, name)
   291  	return db
   292  }
   293  
   294  func (m *MemoryHarness) WithProvider(provider sql.DatabaseProvider) *MemoryHarness {
   295  	ret := *m
   296  	ret.provider = provider
   297  	return &ret
   298  }
   299  
   300  func (m *MemoryHarness) getProvider() sql.DatabaseProvider {
   301  	m.mu.Lock()
   302  	defer m.mu.Unlock()
   303  
   304  	if m.provider == nil {
   305  		m.provider = m.NewDatabaseProvider().(*memory.DbProvider)
   306  	}
   307  
   308  	return m.provider
   309  }
   310  
   311  func (m *MemoryHarness) NewDatabaseProvider() sql.MutableDatabaseProvider {
   312  	return memory.NewDBProviderWithOpts(
   313  		memory.NativeIndexProvider(m.nativeIndexSupport),
   314  		memory.HistoryProvider(true))
   315  }
   316  
   317  func (m *MemoryHarness) Provider() *memory.DbProvider {
   318  	return m.getProvider().(*memory.DbProvider)
   319  }
   320  
   321  func (m *MemoryHarness) NewDatabases(names ...string) []sql.Database {
   322  	var dbs []sql.Database
   323  	for _, name := range names {
   324  		dbs = append(dbs, m.newDatabase(name))
   325  	}
   326  	return dbs
   327  }
   328  
   329  func (m *MemoryHarness) NewReadOnlyEngine(provider sql.DatabaseProvider) (QueryEngine, error) {
   330  	dbs := make([]sql.Database, 0)
   331  	for _, db := range provider.AllDatabases(m.NewContext()) {
   332  		dbs = append(dbs, memory.ReadOnlyDatabase{HistoryDatabase: db.(*memory.HistoryDatabase)})
   333  	}
   334  
   335  	readOnlyProvider := memory.NewDBProviderWithOpts(memory.WithDbsOption(dbs))
   336  	m.provider = readOnlyProvider.(*memory.DbProvider)
   337  
   338  	return NewEngineWithProvider(nil, m, readOnlyProvider), nil
   339  }
   340  
   341  func (m *MemoryHarness) ValidateEngine(ctx *sql.Context, e *sqle.Engine) error {
   342  	return sanityCheckEngine(ctx, e)
   343  }
   344  
   345  func sanityCheckEngine(ctx *sql.Context, e *sqle.Engine) (err error) {
   346  	for _, db := range e.Analyzer.Catalog.AllDatabases(ctx) {
   347  		if err = sanityCheckDatabase(ctx, db); err != nil {
   348  			return err
   349  		}
   350  	}
   351  	return
   352  }
   353  
   354  func sanityCheckDatabase(ctx *sql.Context, db sql.Database) error {
   355  	names, err := db.GetTableNames(ctx)
   356  	if err != nil {
   357  		return err
   358  	}
   359  	for _, name := range names {
   360  		t, ok, err := db.GetTableInsensitive(ctx, name)
   361  		if err != nil {
   362  			return err
   363  		}
   364  		if !ok {
   365  			return fmt.Errorf("expected to find table %s", name)
   366  		}
   367  		if t.Name() != name {
   368  			return fmt.Errorf("unexpected table name (%s !=  %s)", name, t.Name())
   369  		}
   370  	}
   371  	return nil
   372  }