github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/sqle/dprocedures/dolt_merge.go (about)

     1  // Copyright 2022 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package dprocedures
    16  
    17  import (
    18  	"errors"
    19  	"fmt"
    20  	"strings"
    21  
    22  	"github.com/dolthub/go-mysql-server/sql"
    23  	gmstypes "github.com/dolthub/go-mysql-server/sql/types"
    24  	goerrors "gopkg.in/src-d/go-errors.v1"
    25  
    26  	"github.com/dolthub/dolt/go/cmd/dolt/cli"
    27  	"github.com/dolthub/dolt/go/libraries/doltcore/branch_control"
    28  	"github.com/dolthub/dolt/go/libraries/doltcore/dconfig"
    29  	"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
    30  	"github.com/dolthub/dolt/go/libraries/doltcore/env"
    31  	"github.com/dolthub/dolt/go/libraries/doltcore/env/actions"
    32  	"github.com/dolthub/dolt/go/libraries/doltcore/merge"
    33  	"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
    34  	"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
    35  	"github.com/dolthub/dolt/go/libraries/utils/argparser"
    36  	"github.com/dolthub/dolt/go/store/hash"
    37  )
    38  
    39  const DoltMergeWarningCode int = 1105 // Since this our own custom warning we'll use 1105, the code for an unknown error
    40  
    41  const (
    42  	noConflictsOrViolations  int = 0
    43  	hasConflictsOrViolations int = 1
    44  )
    45  
    46  const (
    47  	threeWayMerge    = 0
    48  	fastForwardMerge = 1
    49  )
    50  
    51  // For callers of dolt_merge(), the index of the FastForward column is needed to print results. If the schema of
    52  // the result changes, this will need to be updated.
    53  const MergeProcFFIndex = 1
    54  
    55  var ErrUncommittedChanges = goerrors.NewKind("cannot merge with uncommitted changes")
    56  
    57  var doltMergeSchema = []*sql.Column{
    58  	{
    59  		Name:     "hash",
    60  		Type:     gmstypes.LongText,
    61  		Nullable: true,
    62  	},
    63  	{
    64  		Name:     "fast_forward",
    65  		Type:     gmstypes.Int64,
    66  		Nullable: false,
    67  	},
    68  	{
    69  		Name:     "conflicts",
    70  		Type:     gmstypes.Int64,
    71  		Nullable: false,
    72  	},
    73  	{
    74  		Name:     "message",
    75  		Type:     gmstypes.LongText,
    76  		Nullable: true,
    77  	},
    78  }
    79  
    80  // doltMerge is the stored procedure version for the CLI command `dolt merge`.
    81  func doltMerge(ctx *sql.Context, args ...string) (sql.RowIter, error) {
    82  	commitHash, hasConflicts, ff, message, err := doDoltMerge(ctx, args)
    83  	if err != nil {
    84  		return nil, err
    85  	}
    86  	if message == "" {
    87  		return rowToIter(commitHash, int64(ff), int64(hasConflicts), nil), nil
    88  	} else {
    89  		return rowToIter(commitHash, int64(ff), int64(hasConflicts), message), nil
    90  	}
    91  }
    92  
    93  // doDoltMerge returns has_conflicts and fast_forward status
    94  //
    95  // There are two ways to communicate results procedure to the user:
    96  //  1. return a non-nil error. The error message will be given to the user in their context.
    97  //  2. return a non-empty message to the user. This is needed in non-error cases where the user needs to be informed
    98  //     of something that happened during the merge. This will be added to the message column of the result.
    99  func doDoltMerge(ctx *sql.Context, args []string) (string, int, int, string, error) {
   100  	dbName := ctx.GetCurrentDatabase()
   101  
   102  	if len(dbName) == 0 {
   103  		return "", noConflictsOrViolations, threeWayMerge, "", fmt.Errorf("Empty database name.")
   104  	}
   105  	if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
   106  		return "", noConflictsOrViolations, threeWayMerge, "", err
   107  	}
   108  
   109  	sess := dsess.DSessFromSess(ctx.Session)
   110  
   111  	apr, err := cli.CreateMergeArgParser().Parse(args)
   112  	if err != nil {
   113  		return "", noConflictsOrViolations, threeWayMerge, "", err
   114  	}
   115  
   116  	if len(args) == 0 {
   117  		return "", noConflictsOrViolations, threeWayMerge, "", errors.New("error: Please specify a branch to merge")
   118  	}
   119  
   120  	if apr.ContainsAll(cli.SquashParam, cli.NoFFParam) {
   121  		return "", noConflictsOrViolations, threeWayMerge, "", fmt.Errorf("error: Flags '--%s' and '--%s' cannot be used together.\n", cli.SquashParam, cli.NoFFParam)
   122  	}
   123  
   124  	ws, err := sess.WorkingSet(ctx, dbName)
   125  	if err != nil {
   126  		return "", noConflictsOrViolations, threeWayMerge, "", err
   127  	}
   128  	roots, ok := sess.GetRoots(ctx, dbName)
   129  	if !ok {
   130  		return "", noConflictsOrViolations, threeWayMerge, "", sql.ErrDatabaseNotFound.New(dbName)
   131  	}
   132  
   133  	if apr.Contains(cli.AbortParam) {
   134  		if !ws.MergeActive() {
   135  			return "", noConflictsOrViolations, threeWayMerge, "", fmt.Errorf("fatal: There is no merge to abort")
   136  		}
   137  
   138  		ws, err = merge.AbortMerge(ctx, ws, roots)
   139  		if err != nil {
   140  			return "", noConflictsOrViolations, threeWayMerge, "", err
   141  		}
   142  
   143  		err := sess.SetWorkingSet(ctx, dbName, ws)
   144  		if err != nil {
   145  			return "", noConflictsOrViolations, threeWayMerge, "", err
   146  		}
   147  
   148  		err = sess.CommitWorkingSet(ctx, dbName, sess.GetTransaction())
   149  		if err != nil {
   150  			return "", noConflictsOrViolations, threeWayMerge, "", err
   151  		}
   152  
   153  		return "", noConflictsOrViolations, threeWayMerge, "merge aborted", nil
   154  	}
   155  
   156  	branchName := apr.Arg(0)
   157  
   158  	mergeSpec, err := createMergeSpec(ctx, sess, dbName, apr, branchName)
   159  	if err != nil {
   160  		return "", noConflictsOrViolations, threeWayMerge, "", err
   161  	}
   162  
   163  	dbData, ok := sess.GetDbData(ctx, dbName)
   164  	if !ok {
   165  		return "", noConflictsOrViolations, threeWayMerge, "", fmt.Errorf("Could not load database %s", dbName)
   166  	}
   167  
   168  	headRef, err := dbData.Rsr.CWBHeadRef()
   169  	if err != nil {
   170  		return "", noConflictsOrViolations, threeWayMerge, "", err
   171  	}
   172  	msg := fmt.Sprintf("Merge branch '%s' into %s", branchName, headRef.GetPath())
   173  	if userMsg, mOk := apr.GetValue(cli.MessageArg); mOk {
   174  		msg = userMsg
   175  	}
   176  
   177  	ws, commit, conflicts, fastForward, message, err := performMerge(ctx, sess, ws, dbName, mergeSpec, apr.Contains(cli.NoCommitFlag), msg)
   178  	if err != nil {
   179  		return commit, conflicts, fastForward, "", err
   180  	}
   181  	if conflicts != 0 {
   182  		return commit, conflicts, fastForward, "conflicts found", nil
   183  	}
   184  
   185  	return commit, conflicts, fastForward, message, nil
   186  }
   187  
   188  // performMerge encapsulates server merge logic, switching between
   189  // fast-forward, no fast-forward, merge commit, and merging into working set.
   190  // Returns a new WorkingSet, whether there were merge conflicts, and whether a
   191  // fast-forward was performed. This commits the working set if merge is successful and
   192  // 'no-commit' flag is not defined.
   193  // TODO FF merging commit with constraint violations requires `constraint verify`
   194  func performMerge(
   195  	ctx *sql.Context,
   196  	sess *dsess.DoltSession,
   197  	ws *doltdb.WorkingSet,
   198  	dbName string,
   199  	spec *merge.MergeSpec,
   200  	noCommit bool,
   201  	msg string,
   202  ) (*doltdb.WorkingSet, string, int, int, string, error) {
   203  	// todo: allow merges even when an existing merge is uncommitted
   204  	if ws.MergeActive() {
   205  		return ws, "", noConflictsOrViolations, threeWayMerge, "", doltdb.ErrMergeActive
   206  	}
   207  
   208  	if len(spec.StompedTblNames) != 0 {
   209  		return ws, "", noConflictsOrViolations, threeWayMerge, "", fmt.Errorf("error: local changes would be stomped by merge:\n\t%s\n Please commit your changes before you merge.", strings.Join(spec.StompedTblNames, "\n\t"))
   210  	}
   211  
   212  	dbData, ok := sess.GetDbData(ctx, dbName)
   213  	if !ok {
   214  		return ws, "", noConflictsOrViolations, threeWayMerge, "", fmt.Errorf("failed to get dbData")
   215  	}
   216  
   217  	canFF, err := spec.HeadC.CanFastForwardTo(ctx, spec.MergeC)
   218  	if err != nil {
   219  		switch err {
   220  		case doltdb.ErrIsAhead, doltdb.ErrUpToDate:
   221  			ctx.Warn(DoltMergeWarningCode, err.Error())
   222  			return ws, "", noConflictsOrViolations, threeWayMerge, err.Error(), nil
   223  		default:
   224  			return ws, "", noConflictsOrViolations, threeWayMerge, "", err
   225  		}
   226  	}
   227  
   228  	if canFF {
   229  		if spec.NoFF {
   230  			var commit *doltdb.Commit
   231  			ws, commit, err = executeNoFFMerge(ctx, sess, spec, msg, dbName, ws, noCommit)
   232  			if err == doltdb.ErrUnresolvedConflictsOrViolations {
   233  				// if there are unresolved conflicts, write the resulting working set back to the session and return an
   234  				// error message
   235  				wsErr := sess.SetWorkingSet(ctx, dbName, ws)
   236  				if wsErr != nil {
   237  					return ws, "", hasConflictsOrViolations, threeWayMerge, "", wsErr
   238  				}
   239  				ctx.Warn(DoltMergeWarningCode, err.Error())
   240  				return ws, "", hasConflictsOrViolations, threeWayMerge, "", err
   241  			} else if err != nil {
   242  				return ws, "", noConflictsOrViolations, threeWayMerge, "", err
   243  			}
   244  			cmtHash := ""
   245  			if commit != nil {
   246  				h, err := commit.HashOf()
   247  				if err != nil {
   248  					return ws, "", noConflictsOrViolations, threeWayMerge, "", err // unlikely.
   249  				}
   250  				cmtHash = h.String()
   251  			}
   252  
   253  			return ws, cmtHash, noConflictsOrViolations, threeWayMerge, "merge successful", nil
   254  		}
   255  
   256  		ws, err = executeFFMerge(ctx, dbName, spec.Squash, ws, dbData, spec.MergeC, spec)
   257  		if err != nil {
   258  			return ws, "", noConflictsOrViolations, fastForwardMerge, "", err
   259  		}
   260  		h, err := spec.MergeC.HashOf()
   261  		if err != nil {
   262  			return ws, "", noConflictsOrViolations, fastForwardMerge, "", err
   263  		}
   264  		return ws, h.String(), noConflictsOrViolations, fastForwardMerge, "merge successful", nil
   265  	}
   266  
   267  	dbState, ok, err := sess.LookupDbState(ctx, dbName)
   268  	if err != nil {
   269  		return ws, "", noConflictsOrViolations, threeWayMerge, "", err
   270  	} else if !ok {
   271  		return ws, "", noConflictsOrViolations, threeWayMerge, "", sql.ErrDatabaseNotFound.New(dbName)
   272  	}
   273  
   274  	ws, err = executeMerge(ctx, sess, dbName, spec.Squash, spec.Force, spec.HeadC, spec.MergeC, spec.MergeCSpecStr, ws, dbState.EditOpts(), spec.WorkingDiffs)
   275  	if err == doltdb.ErrUnresolvedConflictsOrViolations {
   276  		// if there are unresolved conflicts, write the resulting working set back to the session and return an
   277  		// error message
   278  		wsErr := sess.SetWorkingSet(ctx, dbName, ws)
   279  		if wsErr != nil {
   280  			return ws, "", hasConflictsOrViolations, threeWayMerge, "", wsErr
   281  		}
   282  
   283  		ctx.Warn(DoltMergeWarningCode, err.Error())
   284  		return ws, "", hasConflictsOrViolations, threeWayMerge, err.Error(), nil
   285  	} else if err != nil {
   286  		return ws, "", noConflictsOrViolations, threeWayMerge, "", err
   287  	}
   288  
   289  	err = sess.SetWorkingSet(ctx, dbName, ws)
   290  	if err != nil {
   291  		return ws, "", noConflictsOrViolations, threeWayMerge, "", err
   292  	}
   293  
   294  	var commit string
   295  	if !noCommit {
   296  		author := fmt.Sprintf("%s <%s>", spec.Name, spec.Email)
   297  		args := []string{"-m", msg, "--author", author}
   298  		if spec.Force {
   299  			args = append(args, "--force")
   300  		}
   301  		commit, _, err = doDoltCommit(ctx, args)
   302  		if err != nil {
   303  			return ws, commit, noConflictsOrViolations, threeWayMerge, "", err
   304  		}
   305  	}
   306  
   307  	return ws, commit, noConflictsOrViolations, threeWayMerge, "merge successful", nil
   308  }
   309  
   310  func executeMerge(
   311  	ctx *sql.Context,
   312  	sess *dsess.DoltSession,
   313  	dbName string,
   314  	squash bool,
   315  	force bool,
   316  	head, cm *doltdb.Commit,
   317  	cmSpec string,
   318  	ws *doltdb.WorkingSet,
   319  	opts editor.Options,
   320  	workingDiffs map[string]hash.Hash,
   321  ) (*doltdb.WorkingSet, error) {
   322  	result, err := merge.MergeCommits(ctx, head, cm, opts)
   323  	if err != nil {
   324  		switch err {
   325  		case doltdb.ErrUpToDate:
   326  			return nil, errors.New("Already up to date.")
   327  		case merge.ErrFastForward:
   328  			panic("fast forward merge")
   329  		default:
   330  			return nil, err
   331  		}
   332  	}
   333  	return mergeRootToWorking(ctx, sess, dbName, squash, force, ws, result, workingDiffs, cm, cmSpec)
   334  }
   335  
   336  func executeFFMerge(ctx *sql.Context, dbName string, squash bool, ws *doltdb.WorkingSet, dbData env.DbData, cm2 *doltdb.Commit, spec *merge.MergeSpec) (*doltdb.WorkingSet, error) {
   337  	stagedRoot, err := cm2.GetRootValue(ctx)
   338  	if err != nil {
   339  		return ws, err
   340  	}
   341  	workingRoot := stagedRoot
   342  	if len(spec.WorkingDiffs) > 0 {
   343  		workingRoot, err = applyChanges(ctx, stagedRoot, spec.WorkingDiffs)
   344  		if err != nil {
   345  			return ws, err
   346  		}
   347  	}
   348  
   349  	// TODO: This is all incredibly suspect, needs to be replaced with library code that is functional instead of
   350  	//  altering global state
   351  	if !squash {
   352  		headRef, err := dbData.Rsr.CWBHeadRef()
   353  		if err != nil {
   354  			return nil, err
   355  		}
   356  		err = dbData.Ddb.FastForward(ctx, headRef, cm2)
   357  		if err != nil {
   358  			return ws, err
   359  		}
   360  	}
   361  
   362  	ws = ws.WithWorkingRoot(workingRoot).WithStagedRoot(stagedRoot)
   363  
   364  	// We need to assign the working set to the session but ensure that its state is not labeled as dirty (ffs are clean
   365  	// merges). Hence, we go ahead and commit the working set to the transaction.
   366  	sess := dsess.DSessFromSess(ctx.Session)
   367  
   368  	err = sess.SetWorkingSet(ctx, dbName, ws)
   369  	if err != nil {
   370  		return ws, err
   371  	}
   372  
   373  	// We only fully commit our transaction when we are not squashing.
   374  	if !squash {
   375  		err = sess.CommitWorkingSet(ctx, dbName, sess.GetTransaction())
   376  		if err != nil {
   377  			return ws, err
   378  		}
   379  	}
   380  
   381  	return ws, nil
   382  }
   383  
   384  // executeNoFFMerge is a helper function for performing a merge that is not a fast-forward merge. It returns the new
   385  // working set, the resulting commit, and an error. If the error is nil, the commit will be non-nil.
   386  func executeNoFFMerge(
   387  	ctx *sql.Context,
   388  	dSess *dsess.DoltSession,
   389  	spec *merge.MergeSpec,
   390  	msg string,
   391  	dbName string,
   392  	ws *doltdb.WorkingSet,
   393  	noCommit bool,
   394  ) (*doltdb.WorkingSet, *doltdb.Commit, error) {
   395  	mergeRoot, err := spec.MergeC.GetRootValue(ctx)
   396  	if err != nil {
   397  		return nil, nil, err
   398  	}
   399  	result := &merge.Result{Root: mergeRoot, Stats: make(map[string]*merge.MergeStats)}
   400  
   401  	ws, err = mergeRootToWorking(ctx, dSess, dbName, false, spec.Force, ws, result, spec.WorkingDiffs, spec.MergeC, spec.MergeCSpecStr)
   402  	if err != nil {
   403  		// This error is recoverable, so we return a working set value along with the error
   404  		return ws, nil, err
   405  	}
   406  
   407  	// Save our work so far in the session, as it will be referenced by the commit call below (badly in need of a
   408  	// refactoring)
   409  	err = dSess.SetWorkingSet(ctx, dbName, ws)
   410  	if err != nil {
   411  		return nil, nil, err
   412  	}
   413  
   414  	// The roots need refreshing after the above
   415  	roots, _ := dSess.GetRoots(ctx, dbName)
   416  
   417  	if noCommit {
   418  		// stage all changes
   419  		roots, err = actions.StageAllTables(ctx, roots, true)
   420  		if err != nil {
   421  			return nil, nil, err
   422  		}
   423  
   424  		err = dSess.SetRoots(ctx, dbName, roots)
   425  		if err != nil {
   426  			return nil, nil, err
   427  		}
   428  
   429  		return ws.WithStagedRoot(roots.Staged), nil, nil
   430  	}
   431  
   432  	pendingCommit, err := dSess.NewPendingCommit(ctx, dbName, roots, actions.CommitStagedProps{
   433  		Message: msg,
   434  		Date:    spec.Date,
   435  		Force:   spec.Force,
   436  		Name:    spec.Name,
   437  		Email:   spec.Email,
   438  	})
   439  	if err != nil {
   440  		return nil, nil, err
   441  	}
   442  
   443  	if pendingCommit == nil {
   444  		return nil, nil, errors.New("nothing to commit")
   445  	}
   446  
   447  	commit, err := dSess.DoltCommit(ctx, dbName, dSess.GetTransaction(), pendingCommit)
   448  	if err != nil {
   449  		return nil, nil, err
   450  	}
   451  
   452  	return ws, commit, nil
   453  }
   454  
   455  func createMergeSpec(ctx *sql.Context, sess *dsess.DoltSession, dbName string, apr *argparser.ArgParseResults, commitSpecStr string) (*merge.MergeSpec, error) {
   456  	ddb, ok := sess.GetDoltDB(ctx, dbName)
   457  
   458  	dbData, ok := sess.GetDbData(ctx, dbName)
   459  
   460  	name, email, err := getNameAndEmail(ctx, apr)
   461  	if err != nil {
   462  		return nil, err
   463  	}
   464  
   465  	t := ctx.QueryTime()
   466  	if commitTimeStr, ok := apr.GetValue(cli.DateParam); ok {
   467  		t, err = dconfig.ParseDate(commitTimeStr)
   468  		if err != nil {
   469  			return nil, err
   470  		}
   471  	}
   472  
   473  	roots, ok := sess.GetRoots(ctx, dbName)
   474  	if !ok {
   475  		return nil, sql.ErrDatabaseNotFound.New(dbName)
   476  	}
   477  
   478  	if apr.Contains(cli.NoCommitFlag) && apr.Contains(cli.CommitFlag) {
   479  		return nil, errors.New("cannot define both 'commit' and 'no-commit' flags at the same time")
   480  	}
   481  	return merge.NewMergeSpec(
   482  		ctx,
   483  		dbData.Rsr,
   484  		ddb,
   485  		roots,
   486  		name,
   487  		email,
   488  		commitSpecStr,
   489  		t,
   490  		merge.WithSquash(apr.Contains(cli.SquashParam)),
   491  		merge.WithNoFF(apr.Contains(cli.NoFFParam)),
   492  		merge.WithForce(apr.Contains(cli.ForceFlag)),
   493  		merge.WithNoCommit(apr.Contains(cli.NoCommitFlag)),
   494  		merge.WithNoEdit(apr.Contains(cli.NoEditFlag)),
   495  	)
   496  }
   497  
   498  func getNameAndEmail(ctx *sql.Context, apr *argparser.ArgParseResults) (string, string, error) {
   499  	var err error
   500  	var name, email string
   501  	if authorStr, ok := apr.GetValue(cli.AuthorParam); ok {
   502  		name, email, err = cli.ParseAuthor(authorStr)
   503  		if err != nil {
   504  			return "", "", err
   505  		}
   506  	} else {
   507  		name = ctx.Client().User
   508  		email = fmt.Sprintf("%s@%s", ctx.Client().User, ctx.Client().Address)
   509  	}
   510  	return name, email, nil
   511  }
   512  
   513  func mergeRootToWorking(
   514  	ctx *sql.Context,
   515  	dSess *dsess.DoltSession,
   516  	dbName string,
   517  	squash, force bool,
   518  	ws *doltdb.WorkingSet,
   519  	merged *merge.Result,
   520  	workingDiffs map[string]hash.Hash,
   521  	cm2 *doltdb.Commit,
   522  	cm2Spec string,
   523  ) (*doltdb.WorkingSet, error) {
   524  	var err error
   525  	staged, working := merged.Root, merged.Root
   526  	if len(workingDiffs) > 0 {
   527  		working, err = applyChanges(ctx, working, workingDiffs)
   528  		if err != nil {
   529  			return ws, err
   530  		}
   531  	}
   532  
   533  	if !squash || merged.HasSchemaConflicts() {
   534  		ws = ws.StartMerge(cm2, cm2Spec)
   535  		tt := merge.SchemaConflictTableNames(merged.SchemaConflicts)
   536  		ws = ws.WithUnmergableTables(tt)
   537  	}
   538  
   539  	ws = ws.WithWorkingRoot(working)
   540  	if !merged.HasMergeArtifacts() && !force {
   541  		ws = ws.WithStagedRoot(staged)
   542  	}
   543  
   544  	err = dSess.SetWorkingSet(ctx, dbName, ws)
   545  	if err != nil {
   546  		return nil, err
   547  	}
   548  
   549  	if merged.HasMergeArtifacts() && !force {
   550  		// this error is recoverable in-session, so we return the new ws along with the error
   551  		return ws, doltdb.ErrUnresolvedConflictsOrViolations
   552  	}
   553  
   554  	return ws, nil
   555  }
   556  
   557  func applyChanges(ctx *sql.Context, root doltdb.RootValue, workingDiffs map[string]hash.Hash) (doltdb.RootValue, error) {
   558  	var err error
   559  	for tblName, h := range workingDiffs {
   560  		root, err = root.SetTableHash(ctx, tblName, h)
   561  
   562  		if err != nil {
   563  			return nil, fmt.Errorf("failed to update table; %w", err)
   564  		}
   565  	}
   566  
   567  	return root, nil
   568  }