github.com/dolthub/go-mysql-server@v0.18.0/enginetest/initialization.go (about)

     1  // Copyright 2022 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package enginetest
    16  
    17  import (
    18  	"fmt"
    19  	"strings"
    20  	"sync/atomic"
    21  	"testing"
    22  
    23  	"github.com/stretchr/testify/require"
    24  
    25  	sqle "github.com/dolthub/go-mysql-server"
    26  	"github.com/dolthub/go-mysql-server/enginetest/scriptgen/setup"
    27  	"github.com/dolthub/go-mysql-server/sql"
    28  	"github.com/dolthub/go-mysql-server/sql/analyzer"
    29  	"github.com/dolthub/go-mysql-server/sql/information_schema"
    30  )
    31  
    32  func NewContext(harness Harness) *sql.Context {
    33  	return newContextSetup(harness.NewContext())
    34  }
    35  
    36  func NewContextWithClient(harness ClientHarness, client sql.Client) *sql.Context {
    37  	return newContextSetup(harness.NewContextWithClient(client))
    38  }
    39  
    40  // TODO: remove
    41  func NewContextWithEngine(harness Harness, engine QueryEngine) *sql.Context {
    42  	return NewContext(harness)
    43  }
    44  
    45  var pid uint64
    46  
    47  func newContextSetup(ctx *sql.Context) *sql.Context {
    48  	// Select a current database if there isn't one yet
    49  	if ctx.GetCurrentDatabase() == "" {
    50  		ctx.SetCurrentDatabase("mydb")
    51  	}
    52  
    53  	ctx.ApplyOpts(sql.WithPid(atomic.AddUint64(&pid, 1)))
    54  
    55  	// We don't want to show any external procedures in our engine tests, so we exclude them
    56  	_ = ctx.SetSessionVariable(ctx, "show_external_procedures", false)
    57  
    58  	return ctx
    59  }
    60  
    61  func NewSession(harness Harness) *sql.Context {
    62  	th, ok := harness.(TransactionHarness)
    63  	if !ok {
    64  		panic("Cannot use NewSession except on a TransactionHarness")
    65  	}
    66  
    67  	ctx := th.NewSession()
    68  	currentDB := ctx.GetCurrentDatabase()
    69  	if currentDB == "" {
    70  		currentDB = "mydb"
    71  		ctx.SetCurrentDatabase(currentDB)
    72  	}
    73  
    74  	ctx.ApplyOpts(sql.WithPid(atomic.AddUint64(&pid, 1)))
    75  
    76  	return ctx
    77  }
    78  
    79  // NewBaseSession returns a new BaseSession compatible with these tests. Most tests will work with any session
    80  // implementation, but for full compatibility use a session based on this one.
    81  func NewBaseSession() *sql.BaseSession {
    82  	return sql.NewBaseSessionWithClientServer("address", sql.Client{Address: "localhost", User: "root"}, 1)
    83  }
    84  
    85  // NewEngineWithProvider returns a new engine with the specified provider
    86  func NewEngineWithProvider(_ *testing.T, harness Harness, provider sql.DatabaseProvider) *sqle.Engine {
    87  	var a *analyzer.Analyzer
    88  
    89  	if harness.Parallelism() > 1 {
    90  		a = analyzer.NewBuilder(provider).WithParallelism(harness.Parallelism()).Build()
    91  	} else {
    92  		a = analyzer.NewDefault(provider)
    93  	}
    94  
    95  	// All tests will run with all privileges on the built-in root account
    96  	a.Catalog.MySQLDb.AddRootAccount()
    97  	// Almost no tests require an information schema that can be updated, but test setup makes it difficult to not
    98  	// provide everywhere
    99  	a.Catalog.InfoSchema = information_schema.NewInformationSchemaDatabase()
   100  
   101  	engine := sqle.New(a, new(sqle.Config))
   102  
   103  	if idh, ok := harness.(IndexDriverHarness); ok {
   104  		idh.InitializeIndexDriver(engine.Analyzer.Catalog.AllDatabases(NewContext(harness)))
   105  	}
   106  
   107  	return engine
   108  }
   109  
   110  // NewEngine creates an engine and sets it up for testing using harness, provider, and setup data given.
   111  func NewEngine(t *testing.T, harness Harness, dbProvider sql.DatabaseProvider, setupData []setup.SetupScript, statsProvider sql.StatsProvider) (*sqle.Engine, error) {
   112  	e := NewEngineWithProvider(t, harness, dbProvider)
   113  	e.Analyzer.Catalog.StatsProvider = statsProvider
   114  	ctx := NewContext(harness)
   115  
   116  	var supportsIndexes bool
   117  	if ih, ok := harness.(IndexHarness); ok && ih.SupportsNativeIndexCreation() {
   118  		supportsIndexes = true
   119  	}
   120  
   121  	// TODO: remove ths, make it explicit everywhere
   122  	if len(setupData) == 0 {
   123  		setupData = setup.MydbData
   124  	}
   125  	return RunSetupScripts(ctx, e, setupData, supportsIndexes)
   126  }
   127  
   128  // RunSetupScripts runs the given setup scripts on the given engine, returning any error
   129  func RunSetupScripts(ctx *sql.Context, e *sqle.Engine, scripts []setup.SetupScript, createIndexes bool) (*sqle.Engine, error) {
   130  	for i := range scripts {
   131  		for _, s := range scripts[i] {
   132  			if !createIndexes {
   133  				if strings.Contains("create index", s) {
   134  					continue
   135  				}
   136  			}
   137  			// ctx.GetLogger().Warnf("running query %s\n", s)
   138  			ctx := ctx.WithQuery(s)
   139  			_, iter, err := e.Query(ctx, s)
   140  			if err != nil {
   141  				return nil, err
   142  			}
   143  			_, err = sql.RowIterToRows(ctx, iter)
   144  			if err != nil {
   145  				return nil, err
   146  			}
   147  		}
   148  	}
   149  	return e, nil
   150  }
   151  
   152  func MustQuery(ctx *sql.Context, e QueryEngine, q string) (sql.Schema, []sql.Row) {
   153  	sch, iter, err := e.Query(ctx, q)
   154  	if err != nil {
   155  		panic(fmt.Sprintf("err running query %s: %s", q, err))
   156  	}
   157  	rows, err := sql.RowIterToRows(ctx, iter)
   158  	if err != nil {
   159  		panic(fmt.Sprintf("err running query %s: %s", q, err))
   160  	}
   161  	return sch, rows
   162  }
   163  
   164  func mustNewEngine(t *testing.T, h Harness) QueryEngine {
   165  	e, err := h.NewEngine(t)
   166  	if err != nil {
   167  		require.NoError(t, err)
   168  	}
   169  	return e
   170  }