github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/sqle/testutil.go (about)

     1  // Copyright 2019 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package sqle
    16  
    17  import (
    18  	"context"
    19  	"errors"
    20  	"fmt"
    21  	"io"
    22  	"strings"
    23  
    24  	sqle "github.com/dolthub/go-mysql-server"
    25  	"github.com/dolthub/go-mysql-server/sql"
    26  	"github.com/dolthub/vitess/go/vt/sqlparser"
    27  
    28  	"github.com/dolthub/dolt/go/libraries/doltcore/branch_control"
    29  	"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
    30  	"github.com/dolthub/dolt/go/libraries/doltcore/doltdb/durable"
    31  	"github.com/dolthub/dolt/go/libraries/doltcore/env"
    32  	"github.com/dolthub/dolt/go/libraries/doltcore/row"
    33  	"github.com/dolthub/dolt/go/libraries/doltcore/schema"
    34  	"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
    35  	"github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlutil"
    36  	"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
    37  	config2 "github.com/dolthub/dolt/go/libraries/utils/config"
    38  	"github.com/dolthub/dolt/go/libraries/utils/filesys"
    39  	"github.com/dolthub/dolt/go/store/prolly/tree"
    40  	"github.com/dolthub/dolt/go/store/types"
    41  	"github.com/dolthub/dolt/go/store/val"
    42  )
    43  
    44  // ExecuteSql executes all the SQL non-select statements given in the string against the root value given and returns
    45  // the updated root, or an error. Statements in the input string are split by `;\n`
    46  func ExecuteSql(dEnv *env.DoltEnv, root doltdb.RootValue, statements string) (doltdb.RootValue, error) {
    47  	tmpDir, err := dEnv.TempTableFilesDir()
    48  	if err != nil {
    49  		return nil, err
    50  	}
    51  
    52  	opts := editor.Options{Deaf: dEnv.DbEaFactory(), Tempdir: tmpDir}
    53  	db, err := NewDatabase(context.Background(), "dolt", dEnv.DbData(), opts)
    54  	if err != nil {
    55  		return nil, err
    56  	}
    57  
    58  	engine, ctx, err := NewTestEngine(dEnv, context.Background(), db)
    59  	if err != nil {
    60  		return nil, err
    61  	}
    62  
    63  	err = ctx.Session.SetSessionVariable(ctx, sql.AutoCommitSessionVar, false)
    64  	if err != nil {
    65  		return nil, err
    66  	}
    67  
    68  	for _, query := range strings.Split(statements, ";\n") {
    69  		if len(strings.Trim(query, " ")) == 0 {
    70  			continue
    71  		}
    72  
    73  		sqlStatement, err := sqlparser.Parse(query)
    74  		if err != nil {
    75  			return nil, err
    76  		}
    77  
    78  		var execErr error
    79  		switch sqlStatement.(type) {
    80  		case *sqlparser.Show:
    81  			return nil, errors.New("Show statements aren't handled")
    82  		case *sqlparser.Select, *sqlparser.OtherRead:
    83  			return nil, errors.New("Select statements aren't handled")
    84  		case *sqlparser.Insert:
    85  			var rowIter sql.RowIter
    86  			_, rowIter, execErr = engine.Query(ctx, query)
    87  			if execErr == nil {
    88  				execErr = drainIter(ctx, rowIter)
    89  			}
    90  		case *sqlparser.DDL, *sqlparser.AlterTable:
    91  			var rowIter sql.RowIter
    92  			_, rowIter, execErr = engine.Query(ctx, query)
    93  			if execErr == nil {
    94  				execErr = drainIter(ctx, rowIter)
    95  			}
    96  		default:
    97  			return nil, fmt.Errorf("Unsupported SQL statement: '%v'.", query)
    98  		}
    99  
   100  		if execErr != nil {
   101  			return nil, execErr
   102  		}
   103  	}
   104  
   105  	err = dsess.DSessFromSess(ctx.Session).CommitTransaction(ctx, ctx.GetTransaction())
   106  	if err != nil {
   107  		return nil, err
   108  	}
   109  
   110  	return db.GetRoot(ctx)
   111  }
   112  
   113  func NewTestSQLCtxWithProvider(ctx context.Context, pro dsess.DoltDatabaseProvider, statsPro sql.StatsProvider) *sql.Context {
   114  	s, err := dsess.NewDoltSession(sql.NewBaseSession(), pro, config2.NewMapConfig(make(map[string]string)), branch_control.CreateDefaultController(ctx), statsPro)
   115  	if err != nil {
   116  		panic(err)
   117  	}
   118  
   119  	s.SetCurrentDatabase("dolt")
   120  	return sql.NewContext(
   121  		ctx,
   122  		sql.WithSession(s),
   123  	)
   124  }
   125  
   126  // NewTestEngine creates a new default engine, and a *sql.Context and initializes indexes and schema fragments.
   127  func NewTestEngine(dEnv *env.DoltEnv, ctx context.Context, db dsess.SqlDatabase) (*sqle.Engine, *sql.Context, error) {
   128  	b := env.GetDefaultInitBranch(dEnv.Config)
   129  	pro, err := NewDoltDatabaseProviderWithDatabase(b, dEnv.FS, db, dEnv.FS)
   130  	if err != nil {
   131  		return nil, nil, err
   132  	}
   133  
   134  	engine := sqle.NewDefault(pro)
   135  
   136  	sqlCtx := NewTestSQLCtxWithProvider(ctx, pro, nil)
   137  	sqlCtx.SetCurrentDatabase(db.Name())
   138  	return engine, sqlCtx, nil
   139  }
   140  
   141  // ExecuteSelect executes the select statement given and returns the resulting rows, or an error if one is encountered.
   142  func ExecuteSelect(dEnv *env.DoltEnv, root doltdb.RootValue, query string) ([]sql.Row, error) {
   143  	dbData := env.DbData{
   144  		Ddb: dEnv.DoltDB,
   145  		Rsw: dEnv.RepoStateWriter(),
   146  		Rsr: dEnv.RepoStateReader(),
   147  	}
   148  
   149  	tmpDir, err := dEnv.TempTableFilesDir()
   150  	if err != nil {
   151  		return nil, err
   152  	}
   153  
   154  	opts := editor.Options{Deaf: dEnv.DbEaFactory(), Tempdir: tmpDir}
   155  	db, err := NewDatabase(context.Background(), "dolt", dbData, opts)
   156  	if err != nil {
   157  		return nil, err
   158  	}
   159  
   160  	engine, ctx, err := NewTestEngine(dEnv, context.Background(), db)
   161  	if err != nil {
   162  		return nil, err
   163  	}
   164  
   165  	_, rowIter, err := engine.Query(ctx, query)
   166  	if err != nil {
   167  		return nil, err
   168  	}
   169  
   170  	var (
   171  		rows   []sql.Row
   172  		rowErr error
   173  		row    sql.Row
   174  	)
   175  	for row, rowErr = rowIter.Next(ctx); rowErr == nil; row, rowErr = rowIter.Next(ctx) {
   176  		rows = append(rows, row)
   177  	}
   178  
   179  	if rowErr != io.EOF {
   180  		return nil, rowErr
   181  	}
   182  
   183  	return rows, nil
   184  }
   185  
   186  // Returns the dolt rows given transformed to sql rows. Exactly the columns in the schema provided are present in the
   187  // final output rows, even if the input rows contain different columns. The tag numbers for columns in the row and
   188  // schema given must match.
   189  func ToSqlRows(sch schema.Schema, rs ...row.Row) []sql.Row {
   190  	sqlRows := make([]sql.Row, len(rs))
   191  	compressedSch := CompressSchema(sch)
   192  	for i := range rs {
   193  		sqlRows[i], _ = sqlutil.DoltRowToSqlRow(CompressRow(sch, rs[i]), compressedSch)
   194  	}
   195  	return sqlRows
   196  }
   197  
   198  // Rewrites the tag numbers for the schema given to start at 0, just like result set schemas. If one or more column
   199  // names are given, only those column names are included in the compressed schema. The column list can also be used to
   200  // reorder the columns as necessary.
   201  func CompressSchema(sch schema.Schema, colNames ...string) schema.Schema {
   202  	var itag uint64
   203  	var cols []schema.Column
   204  
   205  	if len(colNames) > 0 {
   206  		cols = make([]schema.Column, len(colNames))
   207  		for _, colName := range colNames {
   208  			column, ok := sch.GetAllCols().GetByName(colName)
   209  			if !ok {
   210  				panic("No column found for column name " + colName)
   211  			}
   212  			column.Tag = itag
   213  			cols[itag] = column
   214  			itag++
   215  		}
   216  	} else {
   217  		cols = make([]schema.Column, sch.GetAllCols().Size())
   218  		sch.GetAllCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) {
   219  			col.Tag = itag
   220  			cols[itag] = col
   221  			itag++
   222  			return false, nil
   223  		})
   224  	}
   225  
   226  	colCol := schema.NewColCollection(cols...)
   227  	return schema.UnkeyedSchemaFromCols(colCol)
   228  }
   229  
   230  // Rewrites the tag numbers for the schemas given to start at 0, just like result set schemas.
   231  func CompressSchemas(schs ...schema.Schema) schema.Schema {
   232  	var itag uint64
   233  	var cols []schema.Column
   234  
   235  	cols = make([]schema.Column, 0)
   236  	for _, sch := range schs {
   237  		sch.GetAllCols().IterInSortedOrder(func(tag uint64, col schema.Column) (stop bool) {
   238  			col.Tag = itag
   239  			cols = append(cols, col)
   240  			itag++
   241  			return false
   242  		})
   243  	}
   244  
   245  	colCol := schema.NewColCollection(cols...)
   246  	return schema.UnkeyedSchemaFromCols(colCol)
   247  }
   248  
   249  // Rewrites the tag numbers for the row given to begin at zero and be contiguous, just like result set schemas. We don't
   250  // want to just use the field mappings in the result set schema used by sqlselect, since that would only demonstrate
   251  // that the code was consistent with itself, not actually correct.
   252  func CompressRow(sch schema.Schema, r row.Row) row.Row {
   253  	var itag uint64
   254  	compressedRow := make(row.TaggedValues)
   255  
   256  	// TODO: this is probably incorrect and will break for schemas where the tag numbering doesn't match the declared order
   257  	sch.GetAllCols().IterInSortedOrder(func(tag uint64, col schema.Column) (stop bool) {
   258  		if val, ok := r.GetColVal(tag); ok {
   259  			compressedRow[itag] = val
   260  		}
   261  		itag++
   262  		return false
   263  	})
   264  
   265  	// call to compress schema is a no-op in most cases
   266  	r, err := row.New(types.Format_Default, CompressSchema(sch), compressedRow)
   267  
   268  	if err != nil {
   269  		panic(err)
   270  	}
   271  
   272  	return r
   273  }
   274  
   275  // SubsetSchema returns a schema that is a subset of the schema given, with keys and constraints removed. Column names
   276  // must be verified before subsetting. Unrecognized column names will cause a panic.
   277  func SubsetSchema(sch schema.Schema, colNames ...string) schema.Schema {
   278  	srcColls := sch.GetAllCols()
   279  
   280  	var cols []schema.Column
   281  	for _, name := range colNames {
   282  		if col, ok := srcColls.GetByName(name); !ok {
   283  			panic("Unrecognized name " + name)
   284  		} else {
   285  			cols = append(cols, col)
   286  		}
   287  	}
   288  	colColl := schema.NewColCollection(cols...)
   289  	return schema.UnkeyedSchemaFromCols(colColl)
   290  }
   291  
   292  // DoltSchemaFromAlterableTable is a utility for integration tests
   293  func DoltSchemaFromAlterableTable(t *AlterableDoltTable) schema.Schema {
   294  	return t.sch
   295  }
   296  
   297  // DoltTableFromAlterableTable is a utility for integration tests
   298  func DoltTableFromAlterableTable(ctx *sql.Context, t *AlterableDoltTable) *doltdb.Table {
   299  	dt, err := t.DoltTable.DoltTable(ctx)
   300  	if err != nil {
   301  		panic(err)
   302  	}
   303  	return dt
   304  }
   305  
   306  func drainIter(ctx *sql.Context, iter sql.RowIter) error {
   307  	for {
   308  		_, err := iter.Next(ctx)
   309  		if err == io.EOF {
   310  			break
   311  		} else if err != nil {
   312  			closeErr := iter.Close(ctx)
   313  			if closeErr != nil {
   314  				panic(fmt.Errorf("%v\n%v", err, closeErr))
   315  			}
   316  			return err
   317  		}
   318  	}
   319  	return iter.Close(ctx)
   320  }
   321  
   322  func CreateEnvWithSeedData() (*env.DoltEnv, error) {
   323  	const seedData = `
   324  	CREATE TABLE people (
   325  	    id varchar(36) primary key,
   326  	    name varchar(40) not null,
   327  	    age int unsigned,
   328  	    is_married int,
   329  	    title varchar(40),
   330  	    INDEX idx_name (name)
   331  	);
   332  	INSERT INTO people VALUES
   333  		('00000000-0000-0000-0000-000000000000', 'Bill Billerson', 32, 1, 'Senior Dufus'),
   334  		('00000000-0000-0000-0000-000000000001', 'John Johnson', 25, 0, 'Dufus'),
   335  		('00000000-0000-0000-0000-000000000002', 'Rob Robertson', 21, 0, '');`
   336  
   337  	ctx := context.Background()
   338  	dEnv := CreateTestEnv()
   339  	root, err := dEnv.WorkingRoot(ctx)
   340  	if err != nil {
   341  		return nil, err
   342  	}
   343  
   344  	root, err = ExecuteSql(dEnv, root, seedData)
   345  	if err != nil {
   346  		return nil, err
   347  	}
   348  
   349  	err = dEnv.UpdateWorkingRoot(ctx, root)
   350  	if err != nil {
   351  		return nil, err
   352  	}
   353  
   354  	return dEnv, nil
   355  }
   356  
   357  // CreateEmptyTestDatabase creates a test database without any data in it.
   358  func CreateEmptyTestDatabase() (*env.DoltEnv, error) {
   359  	dEnv := CreateTestEnv()
   360  	err := CreateEmptyTestTable(dEnv, PeopleTableName, PeopleTestSchema)
   361  	if err != nil {
   362  		return nil, err
   363  	}
   364  
   365  	err = CreateEmptyTestTable(dEnv, EpisodesTableName, EpisodesTestSchema)
   366  	if err != nil {
   367  		return nil, err
   368  	}
   369  
   370  	err = CreateEmptyTestTable(dEnv, AppearancesTableName, AppearancesTestSchema)
   371  	if err != nil {
   372  		return nil, err
   373  	}
   374  
   375  	return dEnv, nil
   376  }
   377  
// Path prefixes used by CreateTestEnvWithName when laying out the in-memory filesystem of a test environment.
const (
	// TestHomeDirPrefix is prepended to the environment name to form the home directory of a test environment.
	TestHomeDirPrefix = "/user/dolt/"
	// WorkingDirPrefix is prepended to the environment name to form the directory a test environment's
	// in-memory filesystem is rooted in.
	WorkingDirPrefix  = "/user/dolt/datasets/"
)
   382  
   383  // CreateTestEnv creates a new DoltEnv suitable for testing. The CreateTestEnvWithName
   384  // function should generally be preferred over this method, especially when working
   385  // with tests using multiple databases within a MultiRepoEnv.
   386  func CreateTestEnv() *env.DoltEnv {
   387  	return CreateTestEnvWithName("test")
   388  }
   389  
   390  // CreateTestEnvWithName creates a new DoltEnv suitable for testing and uses
   391  // the specified name to distinguish it from other test envs. This function
   392  // should generally be preferred over CreateTestEnv, especially when working with
   393  // tests using multiple databases within a MultiRepoEnv.
   394  func CreateTestEnvWithName(envName string) *env.DoltEnv {
   395  	const name = "billy bob"
   396  	const email = "bigbillieb@fake.horse"
   397  	initialDirs := []string{TestHomeDirPrefix + envName, WorkingDirPrefix + envName}
   398  	homeDirFunc := func() (string, error) { return TestHomeDirPrefix + envName, nil }
   399  	fs := filesys.NewInMemFS(initialDirs, nil, WorkingDirPrefix+envName)
   400  	dEnv := env.Load(context.Background(), homeDirFunc, fs, doltdb.InMemDoltDB+envName, "test")
   401  	cfg, _ := dEnv.Config.GetConfig(env.GlobalConfig)
   402  	cfg.SetStrings(map[string]string{
   403  		config2.UserNameKey:  name,
   404  		config2.UserEmailKey: email,
   405  	})
   406  
   407  	err := dEnv.InitRepo(context.Background(), types.Format_Default, name, email, env.DefaultInitBranch)
   408  
   409  	if err != nil {
   410  		panic("Failed to initialize environment:" + err.Error())
   411  	}
   412  
   413  	return dEnv
   414  }
   415  
   416  // CreateEmptyTestTable creates a new test table with the name, schema, and rows given.
   417  func CreateEmptyTestTable(dEnv *env.DoltEnv, tableName string, sch schema.Schema) error {
   418  	ctx := context.Background()
   419  	root, err := dEnv.WorkingRoot(ctx)
   420  	if err != nil {
   421  		return err
   422  	}
   423  
   424  	vrw := dEnv.DoltDB.ValueReadWriter()
   425  	ns := dEnv.DoltDB.NodeStore()
   426  
   427  	rows, err := durable.NewEmptyIndex(ctx, vrw, ns, sch)
   428  	if err != nil {
   429  		return err
   430  	}
   431  
   432  	indexSet, err := durable.NewIndexSetWithEmptyIndexes(ctx, vrw, ns, sch)
   433  	if err != nil {
   434  		return err
   435  	}
   436  
   437  	tbl, err := doltdb.NewTable(ctx, vrw, ns, sch, rows, indexSet, nil)
   438  	if err != nil {
   439  		return err
   440  	}
   441  
   442  	newRoot, err := root.PutTable(ctx, doltdb.TableName{Name: tableName}, tbl)
   443  	if err != nil {
   444  		return err
   445  	}
   446  
   447  	return dEnv.UpdateWorkingRoot(ctx, newRoot)
   448  }
   449  
   450  // CreateTestDatabase creates a test database with the test data set in it.
   451  func CreateTestDatabase() (*env.DoltEnv, error) {
   452  	ctx := context.Background()
   453  	dEnv, err := CreateEmptyTestDatabase()
   454  	if err != nil {
   455  		return nil, err
   456  	}
   457  
   458  	const simpsonsRowData = `
   459  	INSERT INTO people VALUES
   460  		(0, "Homer", "Simpson", 1, 40, 8.5, NULL, NULL),
   461  		(1, "Marge", "Simpson", 1, 38, 8, "00000000-0000-0000-0000-000000000001", 111),
   462  		(2, "Bart", "Simpson", 0, 10, 9, "00000000-0000-0000-0000-000000000002", 222),
   463  		(3, "Lisa", "Simpson", 0, 8, 10, "00000000-0000-0000-0000-000000000003", 333),
   464  		(4, "Moe", "Szyslak", 0, 48, 6.5, "00000000-0000-0000-0000-000000000004", 444),
   465  		(5, "Barney", "Gumble", 0, 40, 4, "00000000-0000-0000-0000-000000000005", 555);
   466  	INSERT INTO episodes VALUES 
   467  		(1, "Simpsons Roasting On an Open Fire", "1989-12-18 03:00:00", 8.0),
   468  		(2, "Bart the Genius", "1990-01-15 03:00:00", 9.0),
   469  		(3, "Homer's Odyssey", "1990-01-22 03:00:00", 7.0),
   470  		(4, "There's No Disgrace Like Home", "1990-01-29 03:00:00", 8.5);
   471  	INSERT INTO appearances VALUES 
   472  		(0, 1, "Homer is great in this one"),
   473  		(1, 1, "Marge is here too"),
   474  		(0, 2, "Homer is great in this one too"),
   475  		(2, 2, "This episode is named after Bart"),
   476  		(3, 2, "Lisa is here too"),
   477  		(4, 2, "I think there's a prank call scene"),
   478  		(0, 3, "Homer is in every episode"),
   479  		(1, 3, "Marge shows up a lot too"),
   480  		(3, 3, "Lisa is the best Simpson"),
   481  		(5, 3, "I'm making this all up");`
   482  
   483  	root, err := dEnv.WorkingRoot(ctx)
   484  	if err != nil {
   485  		return nil, err
   486  	}
   487  
   488  	root, err = ExecuteSql(dEnv, root, simpsonsRowData)
   489  	if err != nil {
   490  		return nil, err
   491  	}
   492  
   493  	err = dEnv.UpdateWorkingRoot(ctx, root)
   494  	if err != nil {
   495  		return nil, err
   496  	}
   497  
   498  	return dEnv, nil
   499  }
   500  
   501  func SqlRowsFromDurableIndex(idx durable.Index, sch schema.Schema) ([]sql.Row, error) {
   502  	ctx := context.Background()
   503  	var sqlRows []sql.Row
   504  	if types.Format_Default == types.Format_DOLT {
   505  		rowData := durable.ProllyMapFromIndex(idx)
   506  		kd, vd := rowData.Descriptors()
   507  		iter, err := rowData.IterAll(ctx)
   508  		if err != nil {
   509  			return nil, err
   510  		}
   511  		for {
   512  			var k, v val.Tuple
   513  			k, v, err = iter.Next(ctx)
   514  			if err == io.EOF {
   515  				break
   516  			} else if err != nil {
   517  				return nil, err
   518  			}
   519  			sqlRow, err := sqlRowFromTuples(sch, kd, vd, k, v)
   520  			if err != nil {
   521  				return nil, err
   522  			}
   523  			sqlRows = append(sqlRows, sqlRow)
   524  		}
   525  
   526  	} else {
   527  		// types.Format_LD_1
   528  		rowData := durable.NomsMapFromIndex(idx)
   529  		_ = rowData.IterAll(ctx, func(key, value types.Value) error {
   530  			r, err := row.FromNoms(sch, key.(types.Tuple), value.(types.Tuple))
   531  			if err != nil {
   532  				return err
   533  			}
   534  			sqlRow, err := sqlutil.DoltRowToSqlRow(r, sch)
   535  			if err != nil {
   536  				return err
   537  			}
   538  			sqlRows = append(sqlRows, sqlRow)
   539  			return nil
   540  		})
   541  	}
   542  	return sqlRows, nil
   543  }
   544  
   545  func sqlRowFromTuples(sch schema.Schema, kd, vd val.TupleDesc, k, v val.Tuple) (sql.Row, error) {
   546  	var err error
   547  	ctx := context.Background()
   548  	r := make(sql.Row, sch.GetAllCols().Size())
   549  	keyless := schema.IsKeyless(sch)
   550  
   551  	for i, col := range sch.GetAllCols().GetColumns() {
   552  		pos, ok := sch.GetPKCols().TagToIdx[col.Tag]
   553  		if ok {
   554  			r[i], err = tree.GetField(ctx, kd, pos, k, nil)
   555  			if err != nil {
   556  				return nil, err
   557  			}
   558  		}
   559  
   560  		pos, ok = sch.GetNonPKCols().TagToIdx[col.Tag]
   561  		if keyless {
   562  			pos += 1 // compensate for cardinality field
   563  		}
   564  		if ok {
   565  			r[i], err = tree.GetField(ctx, vd, pos, v, nil)
   566  			if err != nil {
   567  				return nil, err
   568  			}
   569  		}
   570  	}
   571  	return r, nil
   572  }