github.com/wrgl/wrgl@v0.14.0/pkg/merge/merger.go (about)

     1  // SPDX-License-Identifier: Apache-2.0
     2  // Copyright © 2022 Wrangle Ltd
     3  
     4  package merge
     5  
     6  import (
     7  	"bytes"
     8  	"context"
     9  	"fmt"
    10  	"reflect"
    11  	"time"
    12  
    13  	"github.com/go-logr/logr"
    14  	"github.com/wrgl/wrgl/pkg/diff"
    15  	"github.com/wrgl/wrgl/pkg/objects"
    16  	"github.com/wrgl/wrgl/pkg/progress"
    17  	"github.com/wrgl/wrgl/pkg/sorter"
    18  )
    19  
    20  type Merger struct {
    21  	db             objects.Store
    22  	errChan        chan error
    23  	progressPeriod time.Duration
    24  	Progress       progress.Tracker
    25  	baseT          *objects.Table
    26  	otherTs        []*objects.Table
    27  	baseSum        []byte
    28  	otherSums      [][]byte
    29  	buf            *diff.BlockBuffer
    30  	collector      *RowCollector
    31  	logger         logr.Logger
    32  }
    33  
    34  func NewMerger(
    35  	db objects.Store,
    36  	collector *RowCollector,
    37  	buf *diff.BlockBuffer,
    38  	progressPeriod time.Duration,
    39  	baseT *objects.Table,
    40  	otherTs []*objects.Table,
    41  	baseSum []byte,
    42  	otherSums [][]byte,
    43  	logger logr.Logger,
    44  ) (m *Merger, err error) {
    45  	m = &Merger{
    46  		db:             db,
    47  		errChan:        make(chan error, len(otherTs)),
    48  		progressPeriod: progressPeriod,
    49  		baseT:          baseT,
    50  		otherTs:        otherTs,
    51  		baseSum:        baseSum,
    52  		otherSums:      otherSums,
    53  		collector:      collector,
    54  		buf:            buf,
    55  		logger:         logger.WithName("Merger"),
    56  	}
    57  	return
    58  }
    59  
    60  func (m *Merger) Error() error {
    61  	close(m.errChan)
    62  	err, ok := <-m.errChan
    63  	if !ok {
    64  		return nil
    65  	}
    66  	return err
    67  }
    68  
    69  func strSliceEqual(left, right []string) bool {
    70  	if len(left) != len(right) {
    71  		return false
    72  	}
    73  	for i, s := range left {
    74  		if s != right[i] {
    75  			return false
    76  		}
    77  	}
    78  	return true
    79  }
    80  
    81  func (m *Merger) mergeTables(colDiff *diff.ColDiff, mergeChan chan<- *Merge, errChan chan<- error, diffChans ...<-chan *objects.Diff) {
    82  	var (
    83  		n        = len(diffChans)
    84  		cases    = make([]reflect.SelectCase, n)
    85  		closed   = make([]bool, n)
    86  		merges   = map[string]*Merge{}
    87  		counter  = map[string]int{}
    88  		resolver = NewRowResolver(m.db, colDiff, m.buf)
    89  	)
    90  
    91  	mergeChan <- &Merge{
    92  		ColDiff: colDiff,
    93  	}
    94  	defer close(mergeChan)
    95  
    96  	for i, ch := range diffChans {
    97  		cases[i] = reflect.SelectCase{
    98  			Dir:  reflect.SelectRecv,
    99  			Chan: reflect.ValueOf(ch),
   100  		}
   101  	}
   102  
   103  	for {
   104  		chosen, recv, ok := reflect.Select(cases)
   105  		if !ok {
   106  			closed[chosen] = true
   107  			allClosed := true
   108  			for _, b := range closed {
   109  				if !b {
   110  					allClosed = false
   111  					break
   112  				}
   113  			}
   114  			if allClosed {
   115  				break
   116  			}
   117  			continue
   118  		}
   119  		d := recv.Interface().(*objects.Diff)
   120  		pkSum := string(d.PK)
   121  		if m, ok := merges[pkSum]; !ok {
   122  			merges[pkSum] = &Merge{
   123  				PK:           d.PK,
   124  				Base:         d.OldSum,
   125  				BaseOffset:   d.OldOffset,
   126  				Others:       make([][]byte, n),
   127  				OtherOffsets: make([]uint32, n),
   128  			}
   129  			merges[pkSum].Others[chosen] = d.Sum
   130  			merges[pkSum].OtherOffsets[chosen] = d.Offset
   131  			counter[pkSum] = 1
   132  		} else {
   133  			m.Others[chosen] = d.Sum
   134  			m.OtherOffsets[chosen] = d.Offset
   135  			counter[pkSum]++
   136  		}
   137  	}
   138  	for _, obj := range merges {
   139  		if obj.Base != nil {
   140  			noChanges := true
   141  			for _, b := range obj.Others {
   142  				if !bytes.Equal(b, obj.Base) {
   143  					noChanges = false
   144  					break
   145  				}
   146  			}
   147  			if noChanges {
   148  				continue
   149  			}
   150  		}
   151  		err := resolver.Resolve(obj)
   152  		if err != nil {
   153  			errChan <- fmt.Errorf("resolve error: %v", err)
   154  			return
   155  		}
   156  		mergeChan <- obj
   157  	}
   158  }
   159  
   160  func (m *Merger) Start() (ch <-chan *Merge, err error) {
   161  	n := len(m.otherTs)
   162  	var pk []string
   163  	for _, t := range m.otherTs {
   164  		if pk == nil {
   165  			pk = t.PrimaryKey()
   166  		} else if !strSliceEqual(pk, t.PrimaryKey()) {
   167  			return nil, fmt.Errorf("can't merge: primary key differs between versions")
   168  		}
   169  	}
   170  	mergeChan := make(chan *Merge)
   171  	diffs := make([]<-chan *objects.Diff, n)
   172  	progs := make([]progress.Tracker, n)
   173  	cols := make([][2][]string, n)
   174  	baseIdx, err := objects.GetTableIndex(m.db, m.baseSum)
   175  	if err != nil {
   176  		return nil, err
   177  	}
   178  	for i, t := range m.otherTs {
   179  		idx, err := objects.GetTableIndex(m.db, m.otherSums[i])
   180  		if err != nil {
   181  			return nil, err
   182  		}
   183  		diffChan, progTracker := diff.DiffTables(
   184  			m.db, m.db, t, m.baseT, idx, baseIdx, m.errChan, m.logger,
   185  			diff.WithProgressInterval(m.progressPeriod*time.Duration(n)),
   186  			diff.WithEmitUnchangedRow(),
   187  		)
   188  		diffs[i] = diffChan
   189  		progs[i] = progTracker
   190  		cols[i] = [2][]string{t.Columns, t.PrimaryKey()}
   191  	}
   192  	colDiff := diff.CompareColumns([2][]string{m.baseT.Columns, m.baseT.PrimaryKey()}, cols...)
   193  	m.Progress = progress.JoinTrackers(progs...)
   194  	go m.mergeTables(colDiff, mergeChan, m.errChan, diffs...)
   195  	return m.collector.CollectResolvedRow(m.errChan, mergeChan), nil
   196  }
   197  
   198  func (m *Merger) SaveResolvedRow(pk []byte, row []string) error {
   199  	return m.collector.SaveResolvedRow(pk, row)
   200  }
   201  
   202  func (m *Merger) SortedBlocks(ctx context.Context, removedCols map[int]struct{}) (<-chan *sorter.Block, error) {
   203  	m.errChan = make(chan error, 1)
   204  	return m.collector.SortedBlocks(ctx, removedCols, m.errChan)
   205  }
   206  
   207  func (m *Merger) SortedRows(ctx context.Context, removedCols map[int]struct{}) (<-chan *sorter.Rows, error) {
   208  	m.errChan = make(chan error, 1)
   209  	return m.collector.SortedRows(ctx, removedCols, m.errChan)
   210  }
   211  
   212  func (m *Merger) Columns(removedCols map[int]struct{}) []string {
   213  	return m.collector.Columns(removedCols)
   214  }
   215  
   216  func (m *Merger) PK() []string {
   217  	return m.collector.PK()
   218  }
   219  
   220  func (m *Merger) Close() error {
   221  	return m.collector.Close()
   222  }