github.com/hasnat/dolt/go@v0.0.0-20210628190320-9eb5d843fbb7/libraries/doltcore/diff/async_differ.go (about)

     1  // Copyright 2019 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package diff
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"time"
    21  
    22  	"golang.org/x/sync/errgroup"
    23  
    24  	"github.com/dolthub/dolt/go/libraries/doltcore/row"
    25  	"github.com/dolthub/dolt/go/libraries/doltcore/schema"
    26  	"github.com/dolthub/dolt/go/libraries/utils/async"
    27  	"github.com/dolthub/dolt/go/store/diff"
    28  	"github.com/dolthub/dolt/go/store/types"
    29  )
    30  
    31  type RowDiffer interface {
    32  	// Start starts the RowDiffer.
    33  	Start(ctx context.Context, from, to types.Map)
    34  
    35  	// GetDiffs returns the requested number of diff.Differences, or times out.
    36  	GetDiffs(numDiffs int, timeout time.Duration) ([]*diff.Difference, bool, error)
    37  
    38  	// Close closes the RowDiffer.
    39  	Close() error
    40  }
    41  
    42  func NewRowDiffer(ctx context.Context, fromSch, toSch schema.Schema, buf int) RowDiffer {
    43  	ad := NewAsyncDiffer(buf)
    44  
    45  	// assumes no PK changes
    46  	// mixed diffing of keyless and pk tables no supported
    47  	if schema.IsKeyless(fromSch) || schema.IsKeyless(toSch) {
    48  		return &keylessDiffer{AsyncDiffer: ad}
    49  	}
    50  
    51  	return ad
    52  }
    53  
    54  // todo: make package private
    55  type AsyncDiffer struct {
    56  	diffChan   chan diff.Difference
    57  	bufferSize int
    58  
    59  	eg       *errgroup.Group
    60  	egCtx    context.Context
    61  	egCancel func()
    62  
    63  	diffStats map[types.DiffChangeType]uint64
    64  }
    65  
    66  var _ RowDiffer = &AsyncDiffer{}
    67  
    68  // todo: make package private once dolthub is migrated
    69  func NewAsyncDiffer(bufferedDiffs int) *AsyncDiffer {
    70  	return &AsyncDiffer{
    71  		diffChan:   make(chan diff.Difference, bufferedDiffs),
    72  		bufferSize: bufferedDiffs,
    73  		egCtx:      context.Background(),
    74  		egCancel:   func() {},
    75  		diffStats:  make(map[types.DiffChangeType]uint64),
    76  	}
    77  }
    78  
    79  func tableDontDescendLists(v1, v2 types.Value) bool {
    80  	kind := v1.Kind()
    81  	return !types.IsPrimitiveKind(kind) && kind != types.TupleKind && kind == v2.Kind() && kind != types.RefKind
    82  }
    83  
    84  func (ad *AsyncDiffer) Start(ctx context.Context, from, to types.Map) {
    85  	ad.start(ctx, func(ctx context.Context) error {
    86  		return diff.Diff(ctx, from, to, ad.diffChan, true, tableDontDescendLists)
    87  	})
    88  }
    89  
    90  func (ad *AsyncDiffer) StartWithRange(ctx context.Context, from, to types.Map, start types.Value, inRange types.ValueInRange) {
    91  	ad.start(ctx, func(ctx context.Context) error {
    92  		return diff.DiffMapRange(ctx, from, to, start, inRange, ad.diffChan, true, tableDontDescendLists)
    93  	})
    94  }
    95  
    96  func (ad *AsyncDiffer) start(ctx context.Context, diffFunc func(ctx context.Context) error) {
    97  	ad.eg, ad.egCtx = errgroup.WithContext(ctx)
    98  	ad.egCancel = async.GoWithCancel(ad.egCtx, ad.eg, func(ctx context.Context) (err error) {
    99  		defer close(ad.diffChan)
   100  		defer func() {
   101  			if r := recover(); r != nil {
   102  				err = fmt.Errorf("panic in diff.Diff: %v", r)
   103  			}
   104  		}()
   105  		return diffFunc(ctx)
   106  	})
   107  }
   108  
   109  func (ad *AsyncDiffer) Close() error {
   110  	ad.egCancel()
   111  	return ad.eg.Wait()
   112  }
   113  
   114  func (ad *AsyncDiffer) getDiffs(numDiffs int, timeoutChan <-chan time.Time, pred diffPredicate) ([]*diff.Difference, bool, error) {
   115  	diffs := make([]*diff.Difference, 0, ad.bufferSize)
   116  	for {
   117  		select {
   118  		case d, more := <-ad.diffChan:
   119  			if more {
   120  				if pred(&d) {
   121  					ad.diffStats[d.ChangeType]++
   122  					diffs = append(diffs, &d)
   123  				}
   124  				if numDiffs != 0 && numDiffs == len(diffs) {
   125  					return diffs, true, nil
   126  				}
   127  			} else {
   128  				return diffs, false, ad.eg.Wait()
   129  			}
   130  		case <-timeoutChan:
   131  			return diffs, true, nil
   132  		case <-ad.egCtx.Done():
   133  			return nil, false, ad.eg.Wait()
   134  		}
   135  	}
   136  }
   137  
   138  var forever <-chan time.Time = make(chan time.Time)
   139  
   140  type diffPredicate func(*diff.Difference) bool
   141  
   142  var alwaysTruePredicate diffPredicate = func(*diff.Difference) bool {
   143  	return true
   144  }
   145  
   146  func hasChangeTypePredicate(changeType types.DiffChangeType) diffPredicate {
   147  	return func(d *diff.Difference) bool {
   148  		return d.ChangeType == changeType
   149  	}
   150  }
   151  
   152  func (ad *AsyncDiffer) GetDiffs(numDiffs int, timeout time.Duration) ([]*diff.Difference, bool, error) {
   153  	if timeout < 0 {
   154  		return ad.GetDiffsWithoutTimeout(numDiffs)
   155  	}
   156  	return ad.getDiffs(numDiffs, time.After(timeout), alwaysTruePredicate)
   157  }
   158  
   159  func (ad *AsyncDiffer) GetDiffsWithFilter(numDiffs int, timeout time.Duration, filterByChangeType types.DiffChangeType) ([]*diff.Difference, bool, error) {
   160  	if timeout < 0 {
   161  		return ad.GetDiffsWithoutTimeoutWithFilter(numDiffs, filterByChangeType)
   162  	}
   163  	return ad.getDiffs(numDiffs, time.After(timeout), hasChangeTypePredicate(filterByChangeType))
   164  }
   165  
   166  func (ad *AsyncDiffer) GetDiffsWithoutTimeoutWithFilter(numDiffs int, filterByChangeType types.DiffChangeType) ([]*diff.Difference, bool, error) {
   167  	return ad.getDiffs(numDiffs, forever, hasChangeTypePredicate(filterByChangeType))
   168  }
   169  
   170  func (ad *AsyncDiffer) GetDiffsWithoutTimeout(numDiffs int) ([]*diff.Difference, bool, error) {
   171  	return ad.getDiffs(numDiffs, forever, alwaysTruePredicate)
   172  }
   173  
   174  type keylessDiffer struct {
   175  	*AsyncDiffer
   176  
   177  	df         diff.Difference
   178  	copiesLeft uint64
   179  }
   180  
   181  var _ RowDiffer = &keylessDiffer{}
   182  
   183  func (kd *keylessDiffer) GetDiffs(numDiffs int, timeout time.Duration) (diffs []*diff.Difference, more bool, err error) {
   184  	timeoutChan := time.After(timeout)
   185  	diffs = make([]*diff.Difference, numDiffs)
   186  	idx := 0
   187  
   188  	for {
   189  		// first populate |diffs| with copies of |kd.df|
   190  		for (idx < numDiffs) && (kd.copiesLeft > 0) {
   191  			diffs[idx] = &kd.df
   192  
   193  			idx++
   194  			kd.copiesLeft--
   195  		}
   196  		if idx == numDiffs {
   197  			return diffs, true, nil
   198  		}
   199  
   200  		// then get another Difference
   201  		var d diff.Difference
   202  		select {
   203  		case <-timeoutChan:
   204  			return diffs, true, nil
   205  
   206  		case <-kd.egCtx.Done():
   207  			return nil, false, kd.eg.Wait()
   208  
   209  		case d, more = <-kd.diffChan:
   210  			if !more {
   211  				return diffs[:idx], more, nil
   212  			}
   213  
   214  			kd.df, kd.copiesLeft, err = convertDiff(d)
   215  			if err != nil {
   216  				return nil, false, err
   217  			}
   218  		}
   219  	}
   220  
   221  }
   222  
   223  // convertDiff reports the cardinality of a change,
   224  // and converts updates to adds or deletes
   225  func convertDiff(df diff.Difference) (diff.Difference, uint64, error) {
   226  	var oldCard uint64
   227  	if df.OldValue != nil {
   228  		v, err := df.OldValue.(types.Tuple).Get(row.KeylessCardinalityValIdx)
   229  		if err != nil {
   230  			return df, 0, err
   231  		}
   232  		oldCard = uint64(v.(types.Uint))
   233  	}
   234  
   235  	var newCard uint64
   236  	if df.NewValue != nil {
   237  		v, err := df.NewValue.(types.Tuple).Get(row.KeylessCardinalityValIdx)
   238  		if err != nil {
   239  			return df, 0, err
   240  		}
   241  		newCard = uint64(v.(types.Uint))
   242  	}
   243  
   244  	switch df.ChangeType {
   245  	case types.DiffChangeRemoved:
   246  		return df, oldCard, nil
   247  
   248  	case types.DiffChangeAdded:
   249  		return df, newCard, nil
   250  
   251  	case types.DiffChangeModified:
   252  		delta := int64(newCard) - int64(oldCard)
   253  		if delta > 0 {
   254  			df.ChangeType = types.DiffChangeAdded
   255  			df.OldValue = nil
   256  			return df, uint64(delta), nil
   257  		} else if delta < 0 {
   258  			df.ChangeType = types.DiffChangeRemoved
   259  			df.NewValue = nil
   260  			return df, uint64(-delta), nil
   261  		} else {
   262  			panic(fmt.Sprintf("diff with delta = 0 for key: %s", df.KeyValue.HumanReadableString()))
   263  		}
   264  	default:
   265  		return df, 0, fmt.Errorf("unexpected DiffChange type %d", df.ChangeType)
   266  	}
   267  }