github.com/wrgl/wrgl@v0.14.0/pkg/merge/merger.go (about) 1 // SPDX-License-Identifier: Apache-2.0 2 // Copyright © 2022 Wrangle Ltd 3 4 package merge 5 6 import ( 7 "bytes" 8 "context" 9 "fmt" 10 "reflect" 11 "time" 12 13 "github.com/go-logr/logr" 14 "github.com/wrgl/wrgl/pkg/diff" 15 "github.com/wrgl/wrgl/pkg/objects" 16 "github.com/wrgl/wrgl/pkg/progress" 17 "github.com/wrgl/wrgl/pkg/sorter" 18 ) 19 20 type Merger struct { 21 db objects.Store 22 errChan chan error 23 progressPeriod time.Duration 24 Progress progress.Tracker 25 baseT *objects.Table 26 otherTs []*objects.Table 27 baseSum []byte 28 otherSums [][]byte 29 buf *diff.BlockBuffer 30 collector *RowCollector 31 logger logr.Logger 32 } 33 34 func NewMerger( 35 db objects.Store, 36 collector *RowCollector, 37 buf *diff.BlockBuffer, 38 progressPeriod time.Duration, 39 baseT *objects.Table, 40 otherTs []*objects.Table, 41 baseSum []byte, 42 otherSums [][]byte, 43 logger logr.Logger, 44 ) (m *Merger, err error) { 45 m = &Merger{ 46 db: db, 47 errChan: make(chan error, len(otherTs)), 48 progressPeriod: progressPeriod, 49 baseT: baseT, 50 otherTs: otherTs, 51 baseSum: baseSum, 52 otherSums: otherSums, 53 collector: collector, 54 buf: buf, 55 logger: logger.WithName("Merger"), 56 } 57 return 58 } 59 60 func (m *Merger) Error() error { 61 close(m.errChan) 62 err, ok := <-m.errChan 63 if !ok { 64 return nil 65 } 66 return err 67 } 68 69 func strSliceEqual(left, right []string) bool { 70 if len(left) != len(right) { 71 return false 72 } 73 for i, s := range left { 74 if s != right[i] { 75 return false 76 } 77 } 78 return true 79 } 80 81 func (m *Merger) mergeTables(colDiff *diff.ColDiff, mergeChan chan<- *Merge, errChan chan<- error, diffChans ...<-chan *objects.Diff) { 82 var ( 83 n = len(diffChans) 84 cases = make([]reflect.SelectCase, n) 85 closed = make([]bool, n) 86 merges = map[string]*Merge{} 87 counter = map[string]int{} 88 resolver = NewRowResolver(m.db, colDiff, m.buf) 89 ) 90 91 mergeChan <- &Merge{ 92 ColDiff: colDiff, 93 } 94 defer close(mergeChan) 95 96 for i, ch := range diffChans { 97 cases[i] = reflect.SelectCase{ 98 Dir: reflect.SelectRecv, 99 Chan: reflect.ValueOf(ch), 100 } 101 } 102 103 for { 104 chosen, recv, ok := reflect.Select(cases) 105 if !ok { 106 closed[chosen] = true 107 allClosed := true 108 for _, b := range closed { 109 if !b { 110 allClosed = false 111 break 112 } 113 } 114 if allClosed { 115 break 116 } 117 continue 118 } 119 d := recv.Interface().(*objects.Diff) 120 pkSum := string(d.PK) 121 if m, ok := merges[pkSum]; !ok { 122 merges[pkSum] = &Merge{ 123 PK: d.PK, 124 Base: d.OldSum, 125 BaseOffset: d.OldOffset, 126 Others: make([][]byte, n), 127 OtherOffsets: make([]uint32, n), 128 } 129 merges[pkSum].Others[chosen] = d.Sum 130 merges[pkSum].OtherOffsets[chosen] = d.Offset 131 counter[pkSum] = 1 132 } else { 133 m.Others[chosen] = d.Sum 134 m.OtherOffsets[chosen] = d.Offset 135 counter[pkSum]++ 136 } 137 } 138 for _, obj := range merges { 139 if obj.Base != nil { 140 noChanges := true 141 for _, b := range obj.Others { 142 if !bytes.Equal(b, obj.Base) { 143 noChanges = false 144 break 145 } 146 } 147 if noChanges { 148 continue 149 } 150 } 151 err := resolver.Resolve(obj) 152 if err != nil { 153 errChan <- fmt.Errorf("resolve error: %v", err) 154 return 155 } 156 mergeChan <- obj 157 } 158 } 159 160 func (m *Merger) Start() (ch <-chan *Merge, err error) { 161 n := len(m.otherTs) 162 var pk []string 163 for _, t := range m.otherTs { 164 if pk == nil { 165 pk = t.PrimaryKey() 166 } else if !strSliceEqual(pk, t.PrimaryKey()) { 167 return nil, fmt.Errorf("can't merge: primary key differs between versions") 168 } 169 } 170 mergeChan := make(chan *Merge) 171 diffs := make([]<-chan *objects.Diff, n) 172 progs := make([]progress.Tracker, n) 173 cols := make([][2][]string, n) 174 baseIdx, err := objects.GetTableIndex(m.db, m.baseSum) 175 if err != nil { 176 return nil, err 177 } 178 for i, t := range m.otherTs { 179 idx, err := objects.GetTableIndex(m.db, m.otherSums[i]) 180 if err != nil { 181 return nil, err 182 } 183 diffChan, progTracker := diff.DiffTables( 184 m.db, m.db, t, m.baseT, idx, baseIdx, m.errChan, m.logger, 185 diff.WithProgressInterval(m.progressPeriod*time.Duration(n)), 186 diff.WithEmitUnchangedRow(), 187 ) 188 diffs[i] = diffChan 189 progs[i] = progTracker 190 cols[i] = [2][]string{t.Columns, t.PrimaryKey()} 191 } 192 colDiff := diff.CompareColumns([2][]string{m.baseT.Columns, m.baseT.PrimaryKey()}, cols...) 193 m.Progress = progress.JoinTrackers(progs...) 194 go m.mergeTables(colDiff, mergeChan, m.errChan, diffs...) 195 return m.collector.CollectResolvedRow(m.errChan, mergeChan), nil 196 } 197 198 func (m *Merger) SaveResolvedRow(pk []byte, row []string) error { 199 return m.collector.SaveResolvedRow(pk, row) 200 } 201 202 func (m *Merger) SortedBlocks(ctx context.Context, removedCols map[int]struct{}) (<-chan *sorter.Block, error) { 203 m.errChan = make(chan error, 1) 204 return m.collector.SortedBlocks(ctx, removedCols, m.errChan) 205 } 206 207 func (m *Merger) SortedRows(ctx context.Context, removedCols map[int]struct{}) (<-chan *sorter.Rows, error) { 208 m.errChan = make(chan error, 1) 209 return m.collector.SortedRows(ctx, removedCols, m.errChan) 210 } 211 212 func (m *Merger) Columns(removedCols map[int]struct{}) []string { 213 return m.collector.Columns(removedCols) 214 } 215 216 func (m *Merger) PK() []string { 217 return m.collector.PK() 218 } 219 220 func (m *Merger) Close() error { 221 return m.collector.Close() 222 }