github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/merge/row_merge_test.go (about)

     1  // Copyright 2019 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package merge
    16  
    17  import (
    18  	"context"
    19  	"strconv"
    20  	"testing"
    21  
    22  	"github.com/dolthub/go-mysql-server/sql"
    23  	"github.com/stretchr/testify/assert"
    24  
    25  	"github.com/dolthub/dolt/go/libraries/doltcore/schema"
    26  	"github.com/dolthub/dolt/go/store/types"
    27  	"github.com/dolthub/dolt/go/store/val"
    28  )
    29  
    30  type nomsRowMergeTest struct {
    31  	name                  string
    32  	row, mergeRow, ancRow types.Value
    33  	sch                   schema.Schema
    34  	expectedResult        types.Value
    35  	expectCellMerge       bool
    36  	expectConflict        bool
    37  }
    38  
    39  type rowMergeTest struct {
    40  	name                                  string
    41  	row, mergeRow, ancRow                 val.Tuple
    42  	mergedSch, leftSch, rightSch, baseSch schema.Schema
    43  	expectedResult                        val.Tuple
    44  	expectCellMerge                       bool
    45  	expectConflict                        bool
    46  }
    47  
    48  type testCase struct {
    49  	name                     string
    50  	row, mergeRow, ancRow    []*int
    51  	rowCnt, mRowCnt, aRowCnt int
    52  	expectedResult           []*int
    53  	expectCellMerge          bool
    54  	expectConflict           bool
    55  }
    56  
    57  // 0 is nil, negative value is invalid
    58  func build(ints ...int) []*int {
    59  	out := make([]*int, len(ints))
    60  	for i, v := range ints {
    61  		if v < 0 {
    62  			panic("invalid")
    63  		}
    64  		if v == 0 {
    65  			continue
    66  		}
    67  		t := v
    68  		out[i] = &t
    69  	}
    70  	return out
    71  }
    72  
    73  var convergentEditCases = []testCase{
    74  	{
    75  		"add same row",
    76  		build(1, 2),
    77  		build(1, 2),
    78  		nil,
    79  		2, 2, 2,
    80  		build(1, 2),
    81  		false,
    82  		false,
    83  	},
    84  	{
    85  		"both delete row",
    86  		nil,
    87  		nil,
    88  		build(1, 2),
    89  		2, 2, 2,
    90  		nil,
    91  		false,
    92  		false,
    93  	},
    94  	{
    95  		"modify row to equal value",
    96  		build(2, 2),
    97  		build(2, 2),
    98  		build(1, 1),
    99  		2, 2, 2,
   100  		build(2, 2),
   101  		false,
   102  		false,
   103  	},
   104  }
   105  
   106  var testCases = []testCase{
   107  	{
   108  		"insert different rows",
   109  		build(1, 2),
   110  		build(2, 3),
   111  		nil,
   112  		2, 2, 2,
   113  		nil,
   114  		false,
   115  		true,
   116  	},
   117  	{
   118  		"delete a row in one, and modify it in other",
   119  		nil,
   120  		build(1, 3),
   121  		build(1, 2),
   122  		2, 2, 2,
   123  		nil,
   124  		false,
   125  		true,
   126  	},
   127  	{
   128  		"modify rows without overlap",
   129  		build(2, 1),
   130  		build(1, 2),
   131  		build(1, 1),
   132  		2, 2, 2,
   133  		build(2, 2),
   134  		true,
   135  		false,
   136  	},
   137  	{
   138  		"modify rows with equal overlapping changes",
   139  		build(2, 2, 255),
   140  		build(2, 3, 255),
   141  		build(1, 2, 0),
   142  		3, 3, 3,
   143  		build(2, 3, 255),
   144  		true,
   145  		false,
   146  	},
   147  	{
   148  		"modify rows with differing overlapping changes",
   149  		build(2, 2, 128),
   150  		build(1, 3, 255),
   151  		build(1, 2, 0),
   152  		3, 3, 3,
   153  		nil,
   154  		false,
   155  		true,
   156  	},
   157  	{
   158  		"modify rows where one adds a column",
   159  		build(2, 2),
   160  		build(1, 3, 255),
   161  		build(1, 2),
   162  		2, 3, 2,
   163  		build(2, 3, 255),
   164  		true,
   165  		false,
   166  	},
   167  	{
   168  		"modify rows where one drops a column",
   169  		build(1, 2, 1),
   170  		build(2, 1),
   171  		build(1, 1, 1),
   172  		3, 2, 3,
   173  		build(2, 2),
   174  		true,
   175  		false,
   176  	},
   177  	// TODO (dhruv): need to fix this test case for new storage format
   178  	//{
   179  	//	"add rows but one holds a new column",
   180  	//	build(1, 1),
   181  	//	build(1, 1, 1),
   182  	//	nil,
   183  	//	2, 3, 2,
   184  	//	build(1, 1, 1),
   185  	//	true,
   186  	//	false,
   187  	//},
   188  	{
   189  		"Delete a row in one, set all null in the other",
   190  		build(0, 0, 0), // build translates zeros into NULL values
   191  		nil,
   192  		build(1, 1, 1),
   193  		3, 3, 3,
   194  		nil,
   195  		false,
   196  		true,
   197  	},
   198  }
   199  
   200  func TestRowMerge(t *testing.T) {
   201  	if types.Format_Default != types.Format_DOLT {
   202  		t.Skip()
   203  	}
   204  
   205  	ctx := sql.NewEmptyContext()
   206  
   207  	tests := make([]rowMergeTest, len(testCases))
   208  	for i, t := range testCases {
   209  		tests[i] = createRowMergeStruct(t)
   210  	}
   211  
   212  	for _, test := range tests {
   213  		t.Run(test.name, func(t *testing.T) {
   214  			v := newValueMerger(test.mergedSch, test.leftSch, test.rightSch, test.baseSch, syncPool, nil)
   215  
   216  			merged, ok, err := v.tryMerge(ctx, test.row, test.mergeRow, test.ancRow)
   217  			assert.NoError(t, err)
   218  			assert.Equal(t, test.expectConflict, !ok)
   219  			vD := test.mergedSch.GetValueDescriptor()
   220  			assert.Equal(t, vD.Format(test.expectedResult), vD.Format(merged))
   221  		})
   222  	}
   223  }
   224  
   225  func TestNomsRowMerge(t *testing.T) {
   226  	if types.Format_Default == types.Format_DOLT {
   227  		t.Skip()
   228  	}
   229  
   230  	testCases := append(testCases, convergentEditCases...)
   231  	tests := make([]nomsRowMergeTest, len(testCases))
   232  	for i, t := range testCases {
   233  		tests[i] = createNomsRowMergeStruct(t)
   234  	}
   235  
   236  	for _, test := range tests {
   237  		t.Run(test.name, func(t *testing.T) {
   238  			rowMergeResult, err := nomsPkRowMerge(context.Background(), types.Format_Default, test.sch, test.row, test.mergeRow, test.ancRow)
   239  			assert.NoError(t, err)
   240  			assert.Equal(t, test.expectedResult, rowMergeResult.mergedRow,
   241  				"expected "+mustString(types.EncodedValue(context.Background(), test.expectedResult))+
   242  					"got "+mustString(types.EncodedValue(context.Background(), rowMergeResult.mergedRow)))
   243  			assert.Equal(t, test.expectCellMerge, rowMergeResult.didCellMerge)
   244  			assert.Equal(t, test.expectConflict, rowMergeResult.isConflict)
   245  		})
   246  	}
   247  }
   248  
   249  func valsToTestTupleWithoutPks(vals []types.Value) types.Value {
   250  	return valsToTestTuple(vals, false)
   251  }
   252  
   253  func valsToTestTupleWithPks(vals []types.Value) types.Value {
   254  	return valsToTestTuple(vals, true)
   255  }
   256  
   257  func valsToTestTuple(vals []types.Value, includePrimaryKeys bool) types.Value {
   258  	if vals == nil {
   259  		return nil
   260  	}
   261  
   262  	tplVals := make([]types.Value, 0, 2*len(vals))
   263  	for i, val := range vals {
   264  		if !types.IsNull(val) {
   265  			tag := i
   266  			// Assume one primary key tag, add 1 to all other tags
   267  			if includePrimaryKeys {
   268  				tag++
   269  			}
   270  			tplVals = append(tplVals, types.Uint(tag))
   271  			tplVals = append(tplVals, val)
   272  		}
   273  	}
   274  
   275  	return mustTuple(types.NewTuple(types.Format_Default, tplVals...))
   276  }
   277  
   278  func createRowMergeStruct(t testCase) rowMergeTest {
   279  	mergedSch := calcMergedSchema(t)
   280  	leftSch := calcSchema(t.rowCnt)
   281  	rightSch := calcSchema(t.mRowCnt)
   282  	baseSch := calcSchema(t.aRowCnt)
   283  
   284  	tpl := buildTup(leftSch, t.row)
   285  	mergeTpl := buildTup(rightSch, t.mergeRow)
   286  	ancTpl := buildTup(baseSch, t.ancRow)
   287  	expectedTpl := buildTup(mergedSch, t.expectedResult)
   288  	return rowMergeTest{
   289  		t.name,
   290  		tpl, mergeTpl, ancTpl,
   291  		mergedSch, leftSch, rightSch, baseSch,
   292  		expectedTpl,
   293  		t.expectCellMerge,
   294  		t.expectConflict}
   295  }
   296  
   297  func createNomsRowMergeStruct(t testCase) nomsRowMergeTest {
   298  	sch := calcMergedSchema(t)
   299  
   300  	tpl := valsToTestTupleWithPks(toVals(t.row))
   301  	mergeTpl := valsToTestTupleWithPks(toVals(t.mergeRow))
   302  	ancTpl := valsToTestTupleWithPks(toVals(t.ancRow))
   303  	expectedTpl := valsToTestTupleWithPks(toVals(t.expectedResult))
   304  	return nomsRowMergeTest{t.name, tpl, mergeTpl, ancTpl, sch, expectedTpl, t.expectCellMerge, t.expectConflict}
   305  }
   306  
   307  func calcMergedSchema(t testCase) schema.Schema {
   308  	longest := t.rowCnt
   309  	if t.mRowCnt > longest {
   310  		longest = t.mRowCnt
   311  	}
   312  	if t.aRowCnt > longest {
   313  		longest = t.aRowCnt
   314  	}
   315  
   316  	return calcSchema(longest)
   317  }
   318  
   319  func calcSchema(nCols int) schema.Schema {
   320  	cols := make([]schema.Column, nCols+1)
   321  	// Schema needs a primary key to be valid, but all the logic being tested works only on the non-key columns.
   322  	cols[0] = schema.NewColumn("primaryKey", 0, types.IntKind, true)
   323  	for i := 0; i < nCols; i++ {
   324  		tag := i + 1
   325  		cols[tag] = schema.NewColumn(strconv.FormatInt(int64(tag), 10), uint64(tag), types.IntKind, false)
   326  	}
   327  
   328  	colColl := schema.NewColCollection(cols...)
   329  	sch := schema.MustSchemaFromCols(colColl)
   330  	return sch
   331  }
   332  
   333  func buildTup(sch schema.Schema, r []*int) val.Tuple {
   334  	if r == nil {
   335  		return nil
   336  	}
   337  
   338  	vD := sch.GetValueDescriptor()
   339  	vB := val.NewTupleBuilder(vD)
   340  	for i, v := range r {
   341  		if v != nil {
   342  			vB.PutInt64(i, int64(*v))
   343  		}
   344  	}
   345  	return vB.Build(syncPool)
   346  }
   347  
   348  func toVals(ints []*int) []types.Value {
   349  	if ints == nil {
   350  		return nil
   351  	}
   352  
   353  	v := make([]types.Value, len(ints))
   354  	for i, d := range ints {
   355  		if d == nil {
   356  			v[i] = types.NullValue
   357  			continue
   358  		}
   359  
   360  		v[i] = types.Int(*d)
   361  	}
   362  	return v
   363  }