// Copyright 2021 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package tree

import (
	"context"
	"io"

	"golang.org/x/sync/errgroup"

	"github.com/dolthub/dolt/go/store/prolly/message"
)

// patchBufferSize is the capacity of the buffered channel that carries
// patches from the diffing goroutine to the goroutine applying them.
const patchBufferSize = 1024

// CollisionFn is a callback that handles 3-way merging of NodeItems when any
// key collision occurs. A typical implementation will attempt a cell-wise merge
// of the tuples, or register a conflict if such a merge is not possible.
// The returned bool reports whether the returned Diff should be emitted as a
// patch; when it is false, no patch is emitted for the colliding key.
type CollisionFn func(left, right Diff) (Diff, bool)

// MergeStats counts the patches emitted during a merge, bucketed by the
// type of the right-hand diff that produced each patch.
type MergeStats struct {
	// Adds is the number of patches whose right-hand diff was an AddedDiff.
	Adds int
	// Modifications is the number of patches whose right-hand diff was a ModifiedDiff.
	Modifications int
	// Removes is the number of patches whose right-hand diff was a RemovedDiff.
	Removes int
}

// ThreeWayMerge implements a three-way merge algorithm using |base| as the common ancestor, |right| as
// the source branch, and |left| as the destination branch. Both |left| and |right| are diff'd against
// |base| to compute merge patches, but rather than applying both sets of patches to |base|, patches from
// |right| are applied directly to |left|. This reduces the amount of write work and improves performance.
// In the case that a key-value pair was modified on both |left| and |right| with different resulting
// values, the CollisionFn is called to perform a cell-wise merge, or to throw a conflict.
45 func ThreeWayMerge[K ~[]byte, O Ordering[K], S message.Serializer]( 46 ctx context.Context, 47 ns NodeStore, 48 left, right, base Node, 49 collide CollisionFn, 50 leftSchemaChange, rightSchemaChange bool, 51 order O, 52 serializer S, 53 ) (final Node, stats MergeStats, err error) { 54 ld, err := DifferFromRoots[K](ctx, ns, ns, base, left, order, leftSchemaChange) 55 if err != nil { 56 return Node{}, MergeStats{}, err 57 } 58 59 rd, err := DifferFromRoots[K](ctx, ns, ns, base, right, order, rightSchemaChange) 60 if err != nil { 61 return Node{}, MergeStats{}, err 62 } 63 64 eg, ctx := errgroup.WithContext(ctx) 65 patches := newPatchBuffer(patchBufferSize) 66 67 // iterate |ld| and |rd| in parallel, populating |patches| 68 eg.Go(func() (err error) { 69 defer func() { 70 if cerr := patches.Close(); err == nil { 71 err = cerr 72 } 73 }() 74 stats, err = sendPatches(ctx, ld, rd, patches, collide) 75 return 76 }) 77 78 // consume |patches| and apply them to |left| 79 eg.Go(func() error { 80 final, err = ApplyMutations[K](ctx, ns, left, order, serializer, patches) 81 return err 82 }) 83 84 if err = eg.Wait(); err != nil { 85 return Node{}, MergeStats{}, err 86 } 87 88 return final, stats, nil 89 } 90 91 // patchBuffer implements MutationIter. It consumes Diffs 92 // from the parallel treeDiffers and transforms them into 93 // patches for the chunker to apply. 94 type patchBuffer struct { 95 buf chan patch 96 } 97 98 var _ MutationIter = patchBuffer{} 99 100 type patch [2]Item 101 102 func newPatchBuffer(sz int) patchBuffer { 103 return patchBuffer{buf: make(chan patch, sz)} 104 } 105 106 func (ps patchBuffer) sendPatch(ctx context.Context, diff Diff) error { 107 p := patch{diff.Key, diff.To} 108 select { 109 case <-ctx.Done(): 110 return ctx.Err() 111 case ps.buf <- p: 112 return nil 113 } 114 } 115 116 // NextMutation implements MutationIter. 
117 func (ps patchBuffer) NextMutation(ctx context.Context) (Item, Item) { 118 var p patch 119 select { 120 case p = <-ps.buf: 121 return p[0], p[1] 122 case <-ctx.Done(): 123 return nil, nil 124 } 125 } 126 127 func (ps patchBuffer) Close() error { 128 close(ps.buf) 129 return nil 130 } 131 132 func sendPatches[K ~[]byte, O Ordering[K]]( 133 ctx context.Context, 134 l, r Differ[K, O], 135 buf patchBuffer, 136 cb CollisionFn, 137 ) (stats MergeStats, err error) { 138 var ( 139 left, right Diff 140 lok, rok = true, true 141 ) 142 143 left, err = l.Next(ctx) 144 if err == io.EOF { 145 err, lok = nil, false 146 } 147 if err != nil { 148 return MergeStats{}, err 149 } 150 151 right, err = r.Next(ctx) 152 if err == io.EOF { 153 err, rok = nil, false 154 } 155 if err != nil { 156 return MergeStats{}, err 157 } 158 159 for lok && rok { 160 cmp := l.order.Compare(K(left.Key), K(right.Key)) 161 162 switch { 163 case cmp < 0: 164 // already in left 165 left, err = l.Next(ctx) 166 if err == io.EOF { 167 err, lok = nil, false 168 } 169 if err != nil { 170 return MergeStats{}, err 171 } 172 173 case cmp > 0: 174 err = buf.sendPatch(ctx, right) 175 if err != nil { 176 return MergeStats{}, err 177 } 178 updateStats(right, &stats) 179 180 right, err = r.Next(ctx) 181 if err == io.EOF { 182 err, rok = nil, false 183 } 184 if err != nil { 185 return MergeStats{}, err 186 } 187 188 case cmp == 0: 189 resolved, ok := cb(left, right) 190 if ok { 191 err = buf.sendPatch(ctx, resolved) 192 updateStats(right, &stats) 193 } 194 if err != nil { 195 return MergeStats{}, err 196 } 197 198 left, err = l.Next(ctx) 199 if err == io.EOF { 200 err, lok = nil, false 201 } 202 if err != nil { 203 return MergeStats{}, err 204 } 205 206 right, err = r.Next(ctx) 207 if err == io.EOF { 208 err, rok = nil, false 209 } 210 if err != nil { 211 return MergeStats{}, err 212 } 213 } 214 } 215 216 if lok { 217 // already in left 218 return stats, nil 219 } 220 221 for rok { 222 err = buf.sendPatch(ctx, right) 
223 if err != nil { 224 return MergeStats{}, err 225 } 226 updateStats(right, &stats) 227 228 right, err = r.Next(ctx) 229 if err == io.EOF { 230 err, rok = nil, false 231 } 232 if err != nil { 233 return MergeStats{}, err 234 } 235 } 236 237 return stats, nil 238 } 239 240 func updateStats(right Diff, stats *MergeStats) { 241 switch right.Type { 242 case AddedDiff: 243 stats.Adds++ 244 case RemovedDiff: 245 stats.Removes++ 246 case ModifiedDiff: 247 stats.Modifications++ 248 } 249 }