github.com/wrgl/wrgl@v0.14.0/pkg/merge/helpers/utils.go (about)

     1  // SPDX-License-Identifier: Apache-2.0
     2  // Copyright © 2022 Wrangle Ltd
     3  
     4  package mergehelpers
     5  
     6  import (
     7  	"context"
     8  	"sort"
     9  	"testing"
    10  
    11  	"github.com/go-logr/logr/testr"
    12  	"github.com/stretchr/testify/require"
    13  	"github.com/wrgl/wrgl/pkg/diff"
    14  	"github.com/wrgl/wrgl/pkg/index"
    15  	"github.com/wrgl/wrgl/pkg/merge"
    16  	"github.com/wrgl/wrgl/pkg/misc"
    17  	"github.com/wrgl/wrgl/pkg/objects"
    18  	"github.com/wrgl/wrgl/pkg/ref"
    19  	"github.com/wrgl/wrgl/pkg/sorter"
    20  )
    21  
    22  func CreateCollector(t *testing.T, db objects.Store, baseCom *objects.Commit) *merge.RowCollector {
    23  	t.Helper()
    24  	discardedRows, err := index.NewHashSet(misc.NewBuffer(nil), 0)
    25  	require.NoError(t, err)
    26  	baseT, err := objects.GetTable(db, baseCom.Table)
    27  	require.NoError(t, err)
    28  	collector, err := merge.NewCollector(db, baseT, discardedRows)
    29  	require.NoError(t, err)
    30  	return collector
    31  }
    32  
    33  func CollectUnresolvedMerges(t *testing.T, merger *merge.Merger) []*merge.Merge {
    34  	t.Helper()
    35  	mergeCh, err := merger.Start()
    36  	require.NoError(t, err)
    37  	merges := []*merge.Merge{}
    38  	for m := range mergeCh {
    39  		merges = append(merges, m)
    40  	}
    41  	sort.SliceStable(merges, func(i, j int) bool {
    42  		if merges[i].ColDiff != nil && merges[j].ColDiff == nil {
    43  			return true
    44  		}
    45  		if merges[j].ColDiff != nil && merges[i].ColDiff == nil {
    46  			return false
    47  		}
    48  		if merges[i].Base == nil && merges[j].Base != nil {
    49  			return true
    50  		}
    51  		if merges[j].Base == nil && merges[i].Base != nil {
    52  			return false
    53  		}
    54  		return string(merges[i].Base) < string(merges[j].Base)
    55  	})
    56  	return merges
    57  }
    58  
    59  func CollectSortedRows(t *testing.T, merger *merge.Merger, removedCols map[int]struct{}) []*sorter.Rows {
    60  	t.Helper()
    61  	rows := []*sorter.Rows{}
    62  	ctx, cancel := context.WithCancel(context.Background())
    63  	defer cancel()
    64  	ch, err := merger.SortedRows(ctx, removedCols)
    65  	require.NoError(t, err)
    66  	for blk := range ch {
    67  		rows = append(rows, blk)
    68  	}
    69  	require.NoError(t, merger.Error())
    70  	return rows
    71  }
    72  
    73  func CollectSortedBlocks(t *testing.T, merger *merge.Merger, removedCols map[int]struct{}) []*sorter.Block {
    74  	t.Helper()
    75  	rows := []*sorter.Block{}
    76  	ctx, cancel := context.WithCancel(context.Background())
    77  	defer cancel()
    78  	ch, err := merger.SortedBlocks(ctx, removedCols)
    79  	require.NoError(t, err)
    80  	for blk := range ch {
    81  		rows = append(rows, blk)
    82  	}
    83  	require.NoError(t, merger.Error())
    84  	return rows
    85  }
    86  
    87  func CreateMerger(t *testing.T, db objects.Store, commits ...[]byte) (*merge.Merger, *diff.BlockBuffer) {
    88  	base, err := ref.SeekCommonAncestor(db, commits...)
    89  	require.NoError(t, err)
    90  	baseCom, err := objects.GetCommit(db, base)
    91  	require.NoError(t, err)
    92  	baseT, err := objects.GetTable(db, baseCom.Table)
    93  	require.NoError(t, err)
    94  	otherTs := make([]*objects.Table, len(commits))
    95  	otherSums := make([][]byte, len(commits))
    96  	for i, sum := range commits {
    97  		com, err := objects.GetCommit(db, sum)
    98  		require.NoError(t, err)
    99  		otherT, err := objects.GetTable(db, com.Table)
   100  		require.NoError(t, err)
   101  		otherTs[i] = otherT
   102  		otherSums[i] = com.Table
   103  	}
   104  	collector := CreateCollector(t, db, baseCom)
   105  	buf, err := diff.BlockBufferWithSingleStore(db, append([]*objects.Table{baseT}, otherTs...))
   106  	require.NoError(t, err)
   107  	merger, err := merge.NewMerger(db, collector, buf, 0, baseT, otherTs, baseCom.Table, otherSums, testr.New(t))
   108  	require.NoError(t, err)
   109  	return merger, buf
   110  }