github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/store/diff/diff.go (about)

     1  // Copyright 2019 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  //
    15  // This file incorporates work covered by the following copyright and
    16  // permission notice:
    17  //
    18  // Copyright 2016 Attic Labs, Inc. All rights reserved.
    19  // Licensed under the Apache License, version 2.0:
    20  // http://www.apache.org/licenses/LICENSE-2.0
    21  
    22  package diff
    23  
    24  import (
    25  	"context"
    26  	"errors"
    27  	"sync/atomic"
    28  
    29  	"golang.org/x/sync/errgroup"
    30  
    31  	"github.com/dolthub/dolt/go/store/types"
    32  )
    33  
    34  type (
    35  	diffFunc     func(ctx context.Context, changeChan chan<- types.ValueChanged) error
    36  	pathPartFunc func(v types.Value) (types.PathPart, error)
    37  	valueFunc    func(k types.Value) (types.Value, error)
    38  )
    39  
    40  // Difference represents a "diff" between two Noms graphs.
    41  type Difference struct {
    42  	// Path to the Value that has changed
    43  	Path types.Path
    44  	// ChangeType indicates the type of diff: modified, added, deleted
    45  	ChangeType types.DiffChangeType
    46  	// OldValue is Value before the change, can be nil if Value was added
    47  	OldValue types.Value
    48  	// NewValue is Value after the change, can be nil if Value was removed
    49  	NewValue types.Value
    50  	// NewKeyValue is used for when elements are added to diffs with a
    51  	// non-primitive key. The new key must available when the map gets updated.
    52  	NewKeyValue types.Value
    53  	// KeyValue holds the key associated with a changed map value
    54  	KeyValue types.Value
    55  }
    56  
    57  func (dif Difference) IsEmpty() bool {
    58  	return dif.Path == nil && dif.OldValue == nil && dif.NewValue == nil
    59  }
    60  
    61  type ShouldDescFunc func(v1, v2 types.Value) bool
    62  
    63  // differ is used internally to hold information necessary for diffing two graphs.
    64  type differ struct {
    65  	// Channel used to send Difference objects back to caller
    66  	diffChan chan<- Difference
    67  	// Use LeftRight diff as opposed to TopDown
    68  	leftRight bool
    69  
    70  	shouldDescend ShouldDescFunc
    71  
    72  	eg         *errgroup.Group
    73  	asyncPanic *atomic.Value
    74  }
    75  
    76  // Diff traverses two graphs simultaneously looking for differences. It returns
    77  // two channels: a DiffReceiveChan that the caller can use to iterate over the
    78  // diffs in the graph and a StopSendChanel that a caller can use to signal the
    79  // Diff function to stop processing.
    80  // Diff returns the Differences in depth-first first order. A 'diff' is defined
    81  // as one of the following conditions:
    82  //   - a Value is Added or Removed from a node in the graph
    83  //   - the type of a Value has changed in the graph
    84  //   - a primitive (i.e. Bool, Float, String, Ref or Blob) Value has changed.
    85  //
    86  // A Difference is not returned when a non-primitive value has been modified. For
    87  // example, a struct field has been changed from one Value of type Employee to
    88  // another. Those modifications are accounted for by the Differences described
    89  // above at a lower point in the graph.
    90  //
    91  // If leftRight is true then the left-right diff is used for ordered sequences
    92  // - see Diff vs DiffLeftRight in Set and Map.
    93  //
    94  // Note: the function sends messages on diffChan and checks whether stopChan has
    95  // been closed to know if it needs to terminate diffing early. To function
    96  // properly it needs to be executed concurrently with code that reads values from
    97  // diffChan. The following is a typical invocation of Diff():
    98  //
    99  //	dChan := make(chan Difference)
   100  //	sChan := make(chan struct{})
   101  //	go func() {
   102  //	    d.Diff(s3, s4, dChan, sChan, leftRight)
   103  //	    close(dChan)
   104  //	}()
   105  //	for dif := range dChan {
   106  //	    <some code>
   107  //	}
   108  func Diff(ctx context.Context, v1, v2 types.Value, dChan chan<- Difference, leftRight bool, descFunc ShouldDescFunc) error {
   109  	f := func(ctx context.Context, d differ, v1, v2 types.Value) error {
   110  		return d.diff(ctx, nil, v1, v2)
   111  	}
   112  
   113  	return diff(ctx, f, v1, v2, dChan, leftRight, descFunc)
   114  }
   115  
   116  func DiffMapRange(ctx context.Context, m1, m2 types.Map, start types.Value, inRange types.ValueInRange, dChan chan<- Difference, leftRight bool, descFunc ShouldDescFunc) error {
   117  	f := func(ctx context.Context, d differ, v1, v2 types.Value) error {
   118  		return d.diffMapsInRange(ctx, nil, m1, m2, start, inRange)
   119  	}
   120  
   121  	return diff(ctx, f, m1, m2, dChan, leftRight, descFunc)
   122  }
   123  
   124  func diff(ctx context.Context,
   125  	f func(ctx context.Context, d differ, v1, v2 types.Value) error,
   126  	v1, v2 types.Value,
   127  	dChan chan<- Difference,
   128  	leftRight bool,
   129  	descFunc ShouldDescFunc) error {
   130  	if descFunc == nil {
   131  		descFunc = ShouldDescend
   132  	}
   133  
   134  	eg, ctx := errgroup.WithContext(ctx)
   135  	d := differ{
   136  		diffChan:      dChan,
   137  		leftRight:     leftRight,
   138  		shouldDescend: descFunc,
   139  
   140  		eg:         eg,
   141  		asyncPanic: new(atomic.Value),
   142  	}
   143  	if !v1.Equals(v2) {
   144  		if !d.shouldDescend(v1, v2) {
   145  			return d.sendDiff(ctx, Difference{Path: nil, ChangeType: types.DiffChangeModified, OldValue: v1, NewValue: v2})
   146  		} else {
   147  			d.GoCatchPanic(func() error {
   148  				return f(ctx, d, v1, v2)
   149  			})
   150  			return d.Wait()
   151  		}
   152  	}
   153  	return nil
   154  }
   155  
   156  func (d differ) diff(ctx context.Context, p types.Path, v1, v2 types.Value) error {
   157  	switch v1.Kind() {
   158  	case types.ListKind:
   159  		return d.diffLists(ctx, p, v1.(types.List), v2.(types.List))
   160  	case types.MapKind:
   161  		return d.diffMaps(ctx, p, v1.(types.Map), v2.(types.Map))
   162  	case types.SetKind:
   163  		return d.diffSets(ctx, p, v1.(types.Set), v2.(types.Set))
   164  	case types.StructKind:
   165  		return d.diffStructs(ctx, p, v1.(types.Struct), v2.(types.Struct))
   166  	default:
   167  		panic("Unrecognized type in diff function")
   168  	}
   169  }
   170  
   171  var AsyncPanicErr = errors.New("async panic")
   172  
   173  func (d differ) GoCatchPanic(f func() error) {
   174  	d.eg.Go(func() (err error) {
   175  		defer func() {
   176  			if r := recover(); r != nil {
   177  				d.asyncPanic.Store(r)
   178  				err = AsyncPanicErr
   179  			}
   180  		}()
   181  		return f()
   182  	})
   183  }
   184  
   185  func (d differ) Wait() error {
   186  	err := d.eg.Wait()
   187  	if p := d.asyncPanic.Load(); p != nil {
   188  		panic(p)
   189  	}
   190  	return err
   191  }
   192  
   193  func (d differ) diffLists(ctx context.Context, p types.Path, v1, v2 types.List) error {
   194  	spliceChan := make(chan types.Splice)
   195  
   196  	d.GoCatchPanic(func() error {
   197  		defer close(spliceChan)
   198  		return v2.Diff(ctx, v1, spliceChan)
   199  	})
   200  
   201  	for splice := range spliceChan {
   202  		if ctx.Err() != nil {
   203  			return ctx.Err()
   204  		}
   205  		if splice.SpRemoved == splice.SpAdded {
   206  			// Heuristic: list only has modifications.
   207  			for i := uint64(0); i < splice.SpRemoved; i++ {
   208  				lastEl, err := v1.Get(ctx, splice.SpAt+i)
   209  				if err != nil {
   210  					return err
   211  				}
   212  
   213  				newEl, err := v2.Get(ctx, splice.SpFrom+i)
   214  				if err != nil {
   215  					return err
   216  				}
   217  
   218  				if d.shouldDescend(lastEl, newEl) {
   219  					idx := types.Float(splice.SpAt + i)
   220  					err := d.diff(ctx, append(p, types.NewIndexPath(idx)), lastEl, newEl)
   221  					if err != nil {
   222  						return err
   223  					}
   224  				} else {
   225  					p1 := p.Append(types.NewIndexPath(types.Float(splice.SpAt + i)))
   226  					oldVal, err := v1.Get(ctx, splice.SpAt+i)
   227  					if err != nil {
   228  						return err
   229  					}
   230  
   231  					newVal, err := v2.Get(ctx, splice.SpFrom+i)
   232  					if err != nil {
   233  						return err
   234  					}
   235  
   236  					dif := Difference{Path: p1, ChangeType: types.DiffChangeModified, OldValue: oldVal, NewValue: newVal}
   237  					err = d.sendDiff(ctx, dif)
   238  					if err != nil {
   239  						return err
   240  					}
   241  				}
   242  			}
   243  			continue
   244  		}
   245  
   246  		// Heuristic: list only has additions/removals.
   247  		for i := uint64(0); i < splice.SpRemoved; i++ {
   248  			p1 := p.Append(types.NewIndexPath(types.Float(splice.SpAt + i)))
   249  			oldVal, err := v1.Get(ctx, splice.SpAt+i)
   250  			if err != nil {
   251  				return err
   252  			}
   253  
   254  			dif := Difference{Path: p1, ChangeType: types.DiffChangeRemoved, OldValue: oldVal, NewValue: nil}
   255  			err = d.sendDiff(ctx, dif)
   256  			if err != nil {
   257  				return err
   258  			}
   259  		}
   260  		for i := uint64(0); i < splice.SpAdded; i++ {
   261  			p1 := p.Append(types.NewIndexPath(types.Float(splice.SpFrom + i)))
   262  			newVal, err := v2.Get(ctx, splice.SpFrom+i)
   263  			if err != nil {
   264  				return err
   265  			}
   266  
   267  			dif := Difference{Path: p1, ChangeType: types.DiffChangeAdded, OldValue: nil, NewValue: newVal}
   268  			err = d.sendDiff(ctx, dif)
   269  			if err != nil {
   270  				return err
   271  			}
   272  		}
   273  	}
   274  
   275  	return nil
   276  }
   277  
   278  func (d differ) diffMaps(ctx context.Context, p types.Path, v1, v2 types.Map) error {
   279  	trueFunc := func(ctx context.Context, value types.Value) (bool, bool, error) {
   280  		return true, false, nil
   281  	}
   282  
   283  	return d.diffMapsInRange(ctx, p, v1, v2, nil, trueFunc)
   284  }
   285  
   286  func (d differ) diffMapsInRange(ctx context.Context, p types.Path, v1, v2 types.Map, start types.Value, inRange types.ValueInRange) error {
   287  	return d.diffOrdered(ctx, p,
   288  		func(v types.Value) (types.PathPart, error) {
   289  			if types.ValueCanBePathIndex(v) {
   290  				return types.NewIndexPath(v), nil
   291  			} else {
   292  				h, err := v.Hash(v1.Format())
   293  
   294  				if err != nil {
   295  					return nil, err
   296  				}
   297  
   298  				return types.NewHashIndexPath(h), nil
   299  			}
   300  		},
   301  		func(ctx context.Context, cc chan<- types.ValueChanged) error {
   302  			if d.leftRight {
   303  				return v2.DiffLeftRightInRange(ctx, v1, start, inRange, cc)
   304  			} else {
   305  				if start != nil {
   306  					panic("not implemented")
   307  				}
   308  
   309  				return v2.Diff(ctx, v1, cc)
   310  			}
   311  		},
   312  		func(k types.Value) (types.Value, error) {
   313  			return k, nil
   314  		},
   315  		func(k types.Value) (types.Value, error) {
   316  			v, _, err := v1.MaybeGet(ctx, k)
   317  			return v, err
   318  		},
   319  		func(k types.Value) (types.Value, error) {
   320  			v, _, err := v2.MaybeGet(ctx, k)
   321  			return v, err
   322  		},
   323  	)
   324  }
   325  
   326  func (d differ) diffStructs(ctx context.Context, p types.Path, v1, v2 types.Struct) error {
   327  	str := func(v types.Value) string {
   328  		return string(v.(types.String))
   329  	}
   330  	return d.diffOrdered(ctx, p,
   331  		func(v types.Value) (types.PathPart, error) {
   332  			return types.NewFieldPath(str(v)), nil
   333  		},
   334  		func(ctx context.Context, cc chan<- types.ValueChanged) error {
   335  			return v2.Diff(ctx, v1, cc)
   336  		},
   337  		func(k types.Value) (types.Value, error) { return k, nil },
   338  		func(k types.Value) (types.Value, error) {
   339  			val, _, err := v1.MaybeGet(str(k))
   340  			return val, err
   341  		},
   342  		func(k types.Value) (types.Value, error) {
   343  			val, _, err := v2.MaybeGet(str(k))
   344  			return val, err
   345  		},
   346  	)
   347  }
   348  
   349  func (d differ) diffSets(ctx context.Context, p types.Path, v1, v2 types.Set) error {
   350  	return d.diffOrdered(ctx, p,
   351  		func(v types.Value) (types.PathPart, error) {
   352  			if types.ValueCanBePathIndex(v) {
   353  				return types.NewIndexPath(v), nil
   354  			}
   355  
   356  			h, err := v.Hash(v1.Format())
   357  
   358  			if err != nil {
   359  				return nil, err
   360  			}
   361  
   362  			return types.NewHashIndexPath(h), nil
   363  		},
   364  		func(ctx context.Context, cc chan<- types.ValueChanged) error {
   365  			if d.leftRight {
   366  				return v2.DiffLeftRight(ctx, v1, cc)
   367  			} else {
   368  				return v2.Diff(ctx, v1, cc)
   369  			}
   370  		},
   371  		func(k types.Value) (types.Value, error) { return k, nil },
   372  		func(k types.Value) (types.Value, error) { return k, nil },
   373  		func(k types.Value) (types.Value, error) { return k, nil },
   374  	)
   375  }
   376  
   377  func (d differ) diffOrdered(ctx context.Context, p types.Path, ppf pathPartFunc, df diffFunc, kf, v1, v2 valueFunc) error {
   378  	changeChan := make(chan types.ValueChanged)
   379  
   380  	d.GoCatchPanic(func() error {
   381  		defer close(changeChan)
   382  		return df(ctx, changeChan)
   383  	})
   384  
   385  	for change := range changeChan {
   386  		if ctx.Err() != nil {
   387  			return ctx.Err()
   388  		}
   389  
   390  		k, err := kf(change.Key)
   391  		if err != nil {
   392  			return err
   393  		}
   394  
   395  		ppfRes, err := ppf(k)
   396  		if err != nil {
   397  			return err
   398  		}
   399  
   400  		p1 := p.Append(ppfRes)
   401  
   402  		switch change.ChangeType {
   403  		case types.DiffChangeAdded:
   404  			newVal, err := v2(change.Key)
   405  			if err != nil {
   406  				return err
   407  			}
   408  
   409  			dif := Difference{Path: p1, ChangeType: types.DiffChangeAdded, OldValue: nil, NewValue: newVal, NewKeyValue: k, KeyValue: change.Key}
   410  			err = d.sendDiff(ctx, dif)
   411  			if err != nil {
   412  				return err
   413  			}
   414  		case types.DiffChangeRemoved:
   415  			oldVal, err := v1(change.Key)
   416  			if err != nil {
   417  				return err
   418  			}
   419  
   420  			dif := Difference{Path: p1, ChangeType: types.DiffChangeRemoved, OldValue: oldVal, KeyValue: change.Key}
   421  			err = d.sendDiff(ctx, dif)
   422  			if err != nil {
   423  				return err
   424  			}
   425  		case types.DiffChangeModified:
   426  			c1, err := v1(change.Key)
   427  			if err != nil {
   428  				return err
   429  			}
   430  
   431  			c2, err := v2(change.Key)
   432  			if err != nil {
   433  				return err
   434  			}
   435  
   436  			if d.shouldDescend(c1, c2) {
   437  				err = d.diff(ctx, p1, c1, c2)
   438  				if err != nil {
   439  					return err
   440  				}
   441  			} else {
   442  				dif := Difference{Path: p1, ChangeType: types.DiffChangeModified, OldValue: c1, NewValue: c2, KeyValue: change.Key}
   443  				err = d.sendDiff(ctx, dif)
   444  				if err != nil {
   445  					return err
   446  				}
   447  			}
   448  		default:
   449  			panic("unknown change type")
   450  		}
   451  	}
   452  
   453  	return nil
   454  }
   455  
   456  // shouldDescend returns true, if Value is not primitive or is a Ref.
   457  func ShouldDescend(v1, v2 types.Value) bool {
   458  	kind := v1.Kind()
   459  	return !types.IsPrimitiveKind(kind) && kind == v2.Kind() && kind != types.RefKind && kind != types.TupleKind
   460  }
   461  
   462  func (d differ) sendDiff(ctx context.Context, dif Difference) error {
   463  	select {
   464  	case <-ctx.Done():
   465  		return ctx.Err()
   466  	case d.diffChan <- dif:
   467  		return nil
   468  	}
   469  }