github.com/hasnat/dolt/go@v0.0.0-20210628190320-9eb5d843fbb7/libraries/doltcore/rebase/rebase.go (about)

     1  // Copyright 2020 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package rebase
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  
    21  	"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
    22  	"github.com/dolthub/dolt/go/libraries/doltcore/doltdocs"
    23  	"github.com/dolthub/dolt/go/libraries/doltcore/env"
    24  	"github.com/dolthub/dolt/go/libraries/doltcore/ref"
    25  	"github.com/dolthub/dolt/go/store/hash"
    26  )
    27  
    28  type visitedSet map[hash.Hash]*doltdb.Commit
    29  
    30  type NeedsRebaseFn func(ctx context.Context, cm *doltdb.Commit) (bool, error)
    31  
    32  // EntireHistory returns a |NeedsRebaseFn| that rebases the entire commit history.
    33  func EntireHistory() NeedsRebaseFn {
    34  	return func(_ context.Context, cm *doltdb.Commit) (bool, error) {
    35  		n, err := cm.NumParents()
    36  		return n != 0, err
    37  	}
    38  }
    39  
    40  // StopAtCommit returns a |NeedsRebaseFn| that rebases the commit history until
    41  // |stopCommit| is reached. It will error if |stopCommit| is not reached.
    42  func StopAtCommit(stopCommit *doltdb.Commit) NeedsRebaseFn {
    43  	return func(ctx context.Context, cm *doltdb.Commit) (bool, error) {
    44  		h, err := cm.HashOf()
    45  		if err != nil {
    46  			return false, err
    47  		}
    48  
    49  		sh, err := stopCommit.HashOf()
    50  		if err != nil {
    51  			return false, err
    52  		}
    53  
    54  		if h.Equal(sh) {
    55  			return false, nil
    56  		}
    57  
    58  		n, err := cm.NumParents()
    59  		if err != nil {
    60  			return false, err
    61  		}
    62  		if n == 0 {
    63  			return false, fmt.Errorf("commit %s is missing from the commit history of at least one rebase head", sh)
    64  		}
    65  
    66  		return true, nil
    67  	}
    68  }
    69  
    70  type ReplayRootFn func(ctx context.Context, root, parentRoot, rebasedParentRoot *doltdb.RootValue) (rebaseRoot *doltdb.RootValue, err error)
    71  
    72  type ReplayCommitFn func(ctx context.Context, commit, parent, rebasedParent *doltdb.Commit) (rebaseRoot *doltdb.RootValue, err error)
    73  
    74  // wrapReplayRootFn converts a |ReplayRootFn| to a |ReplayCommitFn|
    75  func wrapReplayRootFn(fn ReplayRootFn) ReplayCommitFn {
    76  	return func(ctx context.Context, commit, parent, rebasedParent *doltdb.Commit) (rebaseRoot *doltdb.RootValue, err error) {
    77  		root, err := commit.GetRootValue()
    78  		if err != nil {
    79  			return nil, err
    80  		}
    81  
    82  		parentRoot, err := parent.GetRootValue()
    83  		if err != nil {
    84  			return nil, err
    85  		}
    86  
    87  		rebasedParentRoot, err := rebasedParent.GetRootValue()
    88  		if err != nil {
    89  			return nil, err
    90  		}
    91  
    92  		return fn(ctx, root, parentRoot, rebasedParentRoot)
    93  	}
    94  }
    95  
    96  // AllBranches rewrites the history of all branches in the repo using the |replay| function.
    97  func AllBranches(ctx context.Context, dEnv *env.DoltEnv, replay ReplayCommitFn, nerf NeedsRebaseFn) error {
    98  	branches, err := dEnv.DoltDB.GetBranches(ctx)
    99  	if err != nil {
   100  		return err
   101  	}
   102  
   103  	return rebaseRefs(ctx, dEnv.DbData(), replay, nerf, branches...)
   104  }
   105  
   106  // CurrentBranch rewrites the history of the current branch using the |replay| function.
   107  func CurrentBranch(ctx context.Context, dEnv *env.DoltEnv, replay ReplayCommitFn, nerf NeedsRebaseFn) error {
   108  	return rebaseRefs(ctx, dEnv.DbData(), replay, nerf, dEnv.RepoState.CWBHeadRef())
   109  }
   110  
   111  // AllBranchesByRoots rewrites the history of all branches in the repo using the |replay| function.
   112  func AllBranchesByRoots(ctx context.Context, dEnv *env.DoltEnv, replay ReplayRootFn, nerf NeedsRebaseFn) error {
   113  	branches, err := dEnv.DoltDB.GetBranches(ctx)
   114  	if err != nil {
   115  		return err
   116  	}
   117  
   118  	replayCommit := wrapReplayRootFn(replay)
   119  	return rebaseRefs(ctx, dEnv.DbData(), replayCommit, nerf, branches...)
   120  }
   121  
   122  // CurrentBranchByRoot rewrites the history of the current branch using the |replay| function.
   123  func CurrentBranchByRoot(ctx context.Context, dEnv *env.DoltEnv, replay ReplayRootFn, nerf NeedsRebaseFn) error {
   124  	replayCommit := wrapReplayRootFn(replay)
   125  	return rebaseRefs(ctx, dEnv.DbData(), replayCommit, nerf, dEnv.RepoState.CWBHeadRef())
   126  }
   127  
   128  func rebaseRefs(ctx context.Context, dbData env.DbData, replay ReplayCommitFn, nerf NeedsRebaseFn, refs ...ref.DoltRef) error {
   129  	ddb := dbData.Ddb
   130  	rsr := dbData.Rsr
   131  	rsw := dbData.Rsw
   132  	drw := dbData.Drw
   133  
   134  	cwbRef := rsr.CWBHeadRef()
   135  	dd, err := drw.GetDocsOnDisk()
   136  	if err != nil {
   137  		return err
   138  	}
   139  
   140  	heads := make([]*doltdb.Commit, len(refs))
   141  	for i, dRef := range refs {
   142  		heads[i], err = ddb.ResolveCommitRef(ctx, dRef)
   143  		if err != nil {
   144  			return err
   145  		}
   146  	}
   147  
   148  	newHeads, err := rebase(ctx, ddb, replay, nerf, heads...)
   149  	if err != nil {
   150  		return err
   151  	}
   152  
   153  	for i, dRef := range refs {
   154  
   155  		switch dRef.(type) {
   156  		case ref.BranchRef:
   157  			err = ddb.DeleteBranch(ctx, dRef)
   158  			if err != nil {
   159  				return err
   160  			}
   161  			err = ddb.NewBranchAtCommit(ctx, dRef, newHeads[i])
   162  
   163  		default:
   164  			return fmt.Errorf("cannot rebase ref: %s", ref.String(dRef))
   165  		}
   166  
   167  		if err != nil {
   168  			return err
   169  		}
   170  	}
   171  
   172  	cm, err := ddb.ResolveCommitRef(ctx, cwbRef)
   173  	if err != nil {
   174  		return err
   175  	}
   176  
   177  	r, err := cm.GetRootValue()
   178  	if err != nil {
   179  		return err
   180  	}
   181  
   182  	_, err = doltdocs.UpdateRootWithDocs(ctx, r, dd)
   183  	if err != nil {
   184  		return err
   185  	}
   186  
   187  	_, err = env.UpdateStagedRoot(ctx, ddb, rsw, r)
   188  	if err != nil {
   189  		return err
   190  	}
   191  
   192  	_, err = env.UpdateWorkingRoot(ctx, ddb, rsw, r)
   193  	if err != nil {
   194  		return err
   195  	}
   196  
   197  	return err
   198  }
   199  
   200  func rebase(ctx context.Context, ddb *doltdb.DoltDB, replay ReplayCommitFn, nerf NeedsRebaseFn, origins ...*doltdb.Commit) ([]*doltdb.Commit, error) {
   201  	var rebasedCommits []*doltdb.Commit
   202  	vs := make(visitedSet)
   203  	for _, cm := range origins {
   204  		rc, err := rebaseRecursive(ctx, ddb, replay, nerf, vs, cm)
   205  
   206  		if err != nil {
   207  			return nil, err
   208  		}
   209  
   210  		rebasedCommits = append(rebasedCommits, rc)
   211  	}
   212  
   213  	return rebasedCommits, nil
   214  }
   215  
   216  func rebaseRecursive(ctx context.Context, ddb *doltdb.DoltDB, replay ReplayCommitFn, nerf NeedsRebaseFn, vs visitedSet, commit *doltdb.Commit) (*doltdb.Commit, error) {
   217  	commitHash, err := commit.HashOf()
   218  	if err != nil {
   219  		return nil, err
   220  	}
   221  	visitedCommit, found := vs[commitHash]
   222  	if found {
   223  		// base case: reached previously rebased node
   224  		return visitedCommit, nil
   225  	}
   226  
   227  	needToRebase, err := nerf(ctx, commit)
   228  	if err != nil {
   229  		return nil, err
   230  	}
   231  	if !needToRebase {
   232  		// base case: reached bottom of DFS,
   233  		return commit, nil
   234  	}
   235  
   236  	allParents, err := ddb.ResolveAllParents(ctx, commit)
   237  	if err != nil {
   238  		return nil, err
   239  	}
   240  
   241  	if len(allParents) < 1 {
   242  		panic(fmt.Sprintf("commit: %s has no parents", commitHash.String()))
   243  	}
   244  
   245  	var allRebasedParents []*doltdb.Commit
   246  	for _, p := range allParents {
   247  		rp, err := rebaseRecursive(ctx, ddb, replay, nerf, vs, p)
   248  
   249  		if err != nil {
   250  			return nil, err
   251  		}
   252  
   253  		allRebasedParents = append(allRebasedParents, rp)
   254  	}
   255  
   256  	rebasedRoot, err := replay(ctx, commit, allParents[0], allRebasedParents[0])
   257  	if err != nil {
   258  		return nil, err
   259  	}
   260  
   261  	valueHash, err := ddb.WriteRootValue(ctx, rebasedRoot)
   262  	if err != nil {
   263  		return nil, err
   264  	}
   265  
   266  	oldMeta, err := commit.GetCommitMeta()
   267  	if err != nil {
   268  		return nil, err
   269  	}
   270  
   271  	rebasedCommit, err := ddb.CommitDanglingWithParentCommits(ctx, valueHash, allRebasedParents, oldMeta)
   272  	if err != nil {
   273  		return nil, err
   274  	}
   275  
   276  	vs[commitHash] = rebasedCommit
   277  	return rebasedCommit, nil
   278  }