github.com/hasnat/dolt/go@v0.0.0-20210628190320-9eb5d843fbb7/libraries/doltcore/sqle/sqlbatch_test.go (about)

     1  // Copyright 2019 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package sqle
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"math/rand"
    21  	"testing"
    22  
    23  	"github.com/google/go-cmp/cmp"
    24  	"github.com/google/uuid"
    25  	"github.com/stretchr/testify/assert"
    26  	"github.com/stretchr/testify/require"
    27  
    28  	"github.com/dolthub/dolt/go/libraries/doltcore/dtestutils"
    29  	"github.com/dolthub/dolt/go/libraries/doltcore/row"
    30  	. "github.com/dolthub/dolt/go/libraries/doltcore/sql/sqltestutil"
    31  	"github.com/dolthub/dolt/go/store/types"
    32  )
    33  
    34  func TestSqlBatchInserts(t *testing.T) {
    35  	insertStatements := []string{
    36  		`insert into people (id, first_name, last_name, is_married, age, rating, uuid, num_episodes) values
    37  					(7, "Maggie", "Simpson", false, 1, 5.1, '00000000-0000-0000-0000-000000000007', 677)`,
    38  		`insert into people values
    39  					(8, "Milhouse", "VanHouten", false, 1, 5.1, '00000000-0000-0000-0000-000000000008', 677)`,
    40  		`insert into people (id, first_name, last_name) values (9, "Clancey", "Wiggum")`,
    41  		`insert into people (id, first_name, last_name) values
    42  					(10, "Montgomery", "Burns"), (11, "Ned", "Flanders")`,
    43  		`insert into episodes (id, name) values (5, "Bart the General"), (6, "Moaning Lisa")`,
    44  		`insert into episodes (id, name) values (7, "The Call of the Simpsons"), (8, "The Telltale Head")`,
    45  		`insert into episodes (id, name) values (9, "Life on the Fast Lane")`,
    46  		`insert into appearances (character_id, episode_id) values (7,5), (7,6)`,
    47  		`insert into appearances (character_id, episode_id) values (8,7)`,
    48  		`insert into appearances (character_id, episode_id) values (9,8), (9,9)`,
    49  		`insert into appearances (character_id, episode_id) values (10,5), (10,6)`,
    50  		`insert into appearances (character_id, episode_id) values (11,9)`,
    51  	}
    52  
    53  	// Shuffle the inserts so that different tables are interleaved. We're not giving a seed here, so this is
    54  	// deterministic.
    55  	rand.Shuffle(len(insertStatements),
    56  		func(i, j int) {
    57  			insertStatements[i], insertStatements[j] = insertStatements[j], insertStatements[i]
    58  		})
    59  
    60  	dEnv := dtestutils.CreateTestEnv()
    61  	ctx := context.Background()
    62  
    63  	CreateTestDatabase(dEnv, t)
    64  	root, _ := dEnv.WorkingRoot(ctx)
    65  
    66  	db := NewDatabase("dolt", dEnv.DbData())
    67  	engine, sqlCtx, err := NewTestEngine(ctx, db, root)
    68  	require.NoError(t, err)
    69  	DSessFromSess(sqlCtx.Session).EnableBatchedMode()
    70  
    71  	for _, stmt := range insertStatements {
    72  		_, rowIter, err := engine.Query(sqlCtx, stmt)
    73  		require.NoError(t, err)
    74  		require.NoError(t, drainIter(sqlCtx, rowIter))
    75  	}
    76  
    77  	// Before committing the batch, the database should be unchanged from its original state
    78  	allPeopleRows, err := GetAllRows(root, PeopleTableName)
    79  	require.NoError(t, err)
    80  	allEpsRows, err := GetAllRows(root, EpisodesTableName)
    81  	require.NoError(t, err)
    82  	allAppearanceRows, err := GetAllRows(root, AppearancesTableName)
    83  	require.NoError(t, err)
    84  
    85  	assert.ElementsMatch(t, AllPeopleRows, allPeopleRows)
    86  	assert.ElementsMatch(t, AllEpsRows, allEpsRows)
    87  	assert.ElementsMatch(t, AllAppsRows, allAppearanceRows)
    88  
    89  	// Now commit the batch and check for new rows
    90  	err = db.Flush(sqlCtx)
    91  	require.NoError(t, err)
    92  
    93  	var expectedPeople, expectedEpisodes, expectedAppearances []row.Row
    94  
    95  	expectedPeople = append(expectedPeople, AllPeopleRows...)
    96  	expectedPeople = append(expectedPeople,
    97  		NewPeopleRowWithOptionalFields(7, "Maggie", "Simpson", false, 1, 5.1, uuid.MustParse("00000000-0000-0000-0000-000000000007"), 677),
    98  		NewPeopleRowWithOptionalFields(8, "Milhouse", "VanHouten", false, 1, 5.1, uuid.MustParse("00000000-0000-0000-0000-000000000008"), 677),
    99  		newPeopleRow(9, "Clancey", "Wiggum"),
   100  		newPeopleRow(10, "Montgomery", "Burns"),
   101  		newPeopleRow(11, "Ned", "Flanders"),
   102  	)
   103  
   104  	expectedEpisodes = append(expectedEpisodes, AllEpsRows...)
   105  	expectedEpisodes = append(expectedEpisodes,
   106  		newEpsRow(5, "Bart the General"),
   107  		newEpsRow(6, "Moaning Lisa"),
   108  		newEpsRow(7, "The Call of the Simpsons"),
   109  		newEpsRow(8, "The Telltale Head"),
   110  		newEpsRow(9, "Life on the Fast Lane"),
   111  	)
   112  
   113  	expectedAppearances = append(expectedAppearances, AllAppsRows...)
   114  	expectedAppearances = append(expectedAppearances,
   115  		newAppsRow(7, 5),
   116  		newAppsRow(7, 6),
   117  		newAppsRow(8, 7),
   118  		newAppsRow(9, 8),
   119  		newAppsRow(9, 9),
   120  		newAppsRow(10, 5),
   121  		newAppsRow(10, 6),
   122  		newAppsRow(11, 9),
   123  	)
   124  
   125  	root, err = db.GetRoot(sqlCtx)
   126  	require.NoError(t, err)
   127  	allPeopleRows, err = GetAllRows(root, PeopleTableName)
   128  	require.NoError(t, err)
   129  	allEpsRows, err = GetAllRows(root, EpisodesTableName)
   130  	require.NoError(t, err)
   131  	allAppearanceRows, err = GetAllRows(root, AppearancesTableName)
   132  	require.NoError(t, err)
   133  
   134  	assertRowSetsEqual(t, expectedPeople, allPeopleRows)
   135  	assertRowSetsEqual(t, expectedEpisodes, allEpsRows)
   136  	assertRowSetsEqual(t, expectedAppearances, allAppearanceRows)
   137  }
   138  
   139  func TestSqlBatchInsertIgnoreReplace(t *testing.T) {
   140  	t.Skip("Skipped until insert ignore statements supported in go-mysql-server")
   141  
   142  	insertStatements := []string{
   143  		`replace into people (id, first_name, last_name, is_married, age, rating, uuid, num_episodes) values
   144  					(0, "Maggie", "Simpson", false, 1, 5.1, '00000000-0000-0000-0000-000000000007', 677)`,
   145  		`insert ignore into people values
   146  					(2, "Milhouse", "VanHouten", false, 1, 5.1, '00000000-0000-0000-0000-000000000008', 677)`,
   147  	}
   148  
   149  	dEnv := dtestutils.CreateTestEnv()
   150  	ctx := context.Background()
   151  
   152  	CreateTestDatabase(dEnv, t)
   153  	root, _ := dEnv.WorkingRoot(ctx)
   154  
   155  	db := NewDatabase("dolt", dEnv.DbData())
   156  	engine, sqlCtx, err := NewTestEngine(ctx, db, root)
   157  	require.NoError(t, err)
   158  	DSessFromSess(sqlCtx.Session).EnableBatchedMode()
   159  
   160  	for _, stmt := range insertStatements {
   161  		_, rowIter, err := engine.Query(sqlCtx, stmt)
   162  		require.NoError(t, err)
   163  		drainIter(sqlCtx, rowIter)
   164  	}
   165  
   166  	// Before committing the batch, the database should be unchanged from its original state
   167  	allPeopleRows, err := GetAllRows(root, PeopleTableName)
   168  	assert.NoError(t, err)
   169  	assert.ElementsMatch(t, AllPeopleRows, allPeopleRows)
   170  
   171  	// Now commit the batch and check for new rows
   172  	err = db.Flush(sqlCtx)
   173  	require.NoError(t, err)
   174  
   175  	var expectedPeople []row.Row
   176  
   177  	expectedPeople = append(expectedPeople, AllPeopleRows[1:]...) // skip homer
   178  	expectedPeople = append(expectedPeople,
   179  		NewPeopleRowWithOptionalFields(0, "Maggie", "Simpson", false, 1, 5.1, uuid.MustParse("00000000-0000-0000-0000-000000000007"), 677),
   180  	)
   181  
   182  	allPeopleRows, err = GetAllRows(root, PeopleTableName)
   183  	assert.NoError(t, err)
   184  	assertRowSetsEqual(t, expectedPeople, allPeopleRows)
   185  }
   186  
   187  func TestSqlBatchInsertErrors(t *testing.T) {
   188  	dEnv := dtestutils.CreateTestEnv()
   189  	ctx := context.Background()
   190  
   191  	CreateTestDatabase(dEnv, t)
   192  	root, _ := dEnv.WorkingRoot(ctx)
   193  
   194  	db := NewDatabase("dolt", dEnv.DbData())
   195  	engine, sqlCtx, err := NewTestEngine(ctx, db, root)
   196  	require.NoError(t, err)
   197  	DSessFromSess(sqlCtx.Session).EnableBatchedMode()
   198  
   199  	_, rowIter, err := engine.Query(sqlCtx, `insert into people (id, first_name, last_name, is_married, age, rating, uuid, num_episodes) values
   200  					(0, "Maggie", "Simpson", false, 1, 5.1, '00000000-0000-0000-0000-000000000007', 677)`)
   201  	assert.NoError(t, err)
   202  	assert.Error(t, drainIter(sqlCtx, rowIter))
   203  
   204  	// This generates an error at insert time because of the bad type for the uuid column
   205  	_, rowIter, err = engine.Query(sqlCtx, `insert into people values
   206  					(2, "Milhouse", "VanHouten", false, 1, 5.1, true, 677)`)
   207  	assert.NoError(t, err)
   208  	assert.Error(t, drainIter(sqlCtx, rowIter))
   209  
   210  	assert.NoError(t, db.Flush(sqlCtx))
   211  }
   212  
   213  func assertRowSetsEqual(t *testing.T, expected, actual []row.Row) {
   214  	equal, diff := rowSetsEqual(expected, actual)
   215  	assert.True(t, equal, diff)
   216  }
   217  
   218  // Returns whether the two slices of rows contain the same elements using set semantics (no duplicates), and an error
   219  // string if they aren't.
   220  func rowSetsEqual(expected, actual []row.Row) (bool, string) {
   221  	if len(expected) != len(actual) {
   222  		return false, fmt.Sprintf("Sets have different sizes: expected %d, was %d", len(expected), len(actual))
   223  	}
   224  
   225  	for _, ex := range expected {
   226  		if !containsRow(actual, ex) {
   227  			return false, fmt.Sprintf("Missing row: %v", ex)
   228  		}
   229  	}
   230  
   231  	return true, ""
   232  }
   233  
   234  func containsRow(rs []row.Row, r row.Row) bool {
   235  	for _, r2 := range rs {
   236  		equal, _ := rowsEqual(r, r2)
   237  		if equal {
   238  			return true
   239  		}
   240  	}
   241  	return false
   242  }
   243  
   244  func rowsEqual(expected, actual row.Row) (bool, string) {
   245  	er, ar := make(map[uint64]types.Value), make(map[uint64]types.Value)
   246  	_, err := expected.IterCols(func(t uint64, v types.Value) (bool, error) {
   247  		er[t] = v
   248  		return false, nil
   249  	})
   250  
   251  	if err != nil {
   252  		panic(err)
   253  	}
   254  
   255  	_, err = actual.IterCols(func(t uint64, v types.Value) (bool, error) {
   256  		ar[t] = v
   257  		return false, nil
   258  	})
   259  
   260  	if err != nil {
   261  		panic(err)
   262  	}
   263  
   264  	opts := cmp.Options{cmp.AllowUnexported(), dtestutils.FloatComparer, dtestutils.TimestampComparer}
   265  	eq := cmp.Equal(er, ar, opts)
   266  	var diff string
   267  	if !eq {
   268  		diff = cmp.Diff(er, ar, opts)
   269  	}
   270  	return eq, diff
   271  }
   272  
   273  func newPeopleRow(id int, firstName, lastName string) row.Row {
   274  	vals := row.TaggedValues{
   275  		IdTag:        types.Int(id),
   276  		FirstNameTag: types.String(firstName),
   277  		LastNameTag:  types.String(lastName),
   278  	}
   279  
   280  	r, err := row.New(types.Format_Default, PeopleTestSchema, vals)
   281  
   282  	if err != nil {
   283  		panic(err)
   284  	}
   285  
   286  	return r
   287  }
   288  
   289  func newEpsRow(id int, name string) row.Row {
   290  	vals := row.TaggedValues{
   291  		EpisodeIdTag: types.Int(id),
   292  		EpNameTag:    types.String(name),
   293  	}
   294  
   295  	r, err := row.New(types.Format_Default, EpisodesTestSchema, vals)
   296  
   297  	if err != nil {
   298  		panic(err)
   299  	}
   300  
   301  	return r
   302  }
   303  
   304  func newAppsRow(charId int, epId int) row.Row {
   305  	vals := row.TaggedValues{
   306  		AppCharacterTag: types.Int(charId),
   307  		AppEpTag:        types.Int(epId),
   308  	}
   309  
   310  	r, err := row.New(types.Format_Default, AppearancesTestSchema, vals)
   311  
   312  	if err != nil {
   313  		panic(err)
   314  	}
   315  
   316  	return r
   317  }