github.com/hasnat/dolt/go@v0.0.0-20210628190320-9eb5d843fbb7/libraries/doltcore/sqle/dfunctions/dolt_merge.go (about)

     1  // Copyright 2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package dfunctions
    16  
    17  import (
    18  	"errors"
    19  	"fmt"
    20  	"strings"
    21  
    22  	"github.com/dolthub/go-mysql-server/sql"
    23  	"github.com/dolthub/go-mysql-server/sql/expression"
    24  
    25  	"github.com/dolthub/dolt/go/cmd/dolt/cli"
    26  	"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
    27  	"github.com/dolthub/dolt/go/libraries/doltcore/env"
    28  	"github.com/dolthub/dolt/go/libraries/doltcore/env/actions"
    29  	"github.com/dolthub/dolt/go/libraries/doltcore/merge"
    30  	"github.com/dolthub/dolt/go/libraries/doltcore/sqle"
    31  	"github.com/dolthub/dolt/go/libraries/utils/argparser"
    32  )
    33  
    34  const DoltMergeFuncName = "dolt_merge"
    35  
    36  type DoltMergeFunc struct {
    37  	expression.NaryExpression
    38  }
    39  
    40  func (d DoltMergeFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
    41  	dbName := ctx.GetCurrentDatabase()
    42  
    43  	if len(dbName) == 0 {
    44  		return 1, fmt.Errorf("Empty database name.")
    45  	}
    46  
    47  	sess := sqle.DSessFromSess(ctx.Session)
    48  	dbData, ok := sess.GetDbData(dbName)
    49  
    50  	if !ok {
    51  		return 1, fmt.Errorf("Could not load database %s", dbName)
    52  	}
    53  
    54  	ap := cli.CreateMergeArgParser()
    55  	args, err := getDoltArgs(ctx, row, d.Children())
    56  
    57  	if err != nil {
    58  		return nil, err
    59  	}
    60  
    61  	apr, err := ap.Parse(args)
    62  	if err != nil {
    63  		return nil, err
    64  	}
    65  
    66  	if apr.ContainsAll(cli.SquashParam, cli.NoFFParam) {
    67  		return 1, fmt.Errorf("error: Flags '--%s' and '--%s' cannot be used together.\n", cli.SquashParam, cli.NoFFParam)
    68  	}
    69  
    70  	if apr.Contains(cli.AbortParam) {
    71  		if !dbData.Rsr.IsMergeActive() {
    72  			return 1, fmt.Errorf("fatal: There is no merge to abort")
    73  		}
    74  
    75  		err = abortMerge(ctx, dbData)
    76  
    77  		if err != nil {
    78  			return 1, err
    79  		}
    80  
    81  		return "Merge aborted", nil
    82  	}
    83  
    84  	// The first argument should be the branch name.
    85  	branchName := apr.Arg(0)
    86  
    87  	ddb, ok := sess.GetDoltDB(dbName)
    88  	if !ok {
    89  		return nil, sql.ErrDatabaseNotFound.New(dbName)
    90  	}
    91  
    92  	root, ok := sess.GetRoot(dbName)
    93  	if !ok {
    94  		return nil, sql.ErrDatabaseNotFound.New(dbName)
    95  	}
    96  
    97  	hasConflicts, err := root.HasConflicts(ctx)
    98  	if err != nil {
    99  		return 1, err
   100  	}
   101  
   102  	if hasConflicts {
   103  		return 1, doltdb.ErrUnresolvedConflicts
   104  	}
   105  
   106  	if dbData.Rsr.IsMergeActive() {
   107  		return 1, doltdb.ErrMergeActive
   108  	}
   109  
   110  	head, hh, headRoot, err := getHead(ctx, sess, dbName)
   111  	if err != nil {
   112  		return nil, err
   113  	}
   114  
   115  	err = checkForUncommittedChanges(root, headRoot)
   116  	if err != nil {
   117  		return nil, err
   118  	}
   119  
   120  	cm, cmh, err := getBranchCommit(ctx, branchName, ddb)
   121  	if err != nil {
   122  		return nil, err
   123  	}
   124  
   125  	// No need to write a merge commit, if the head can ffw to the commit coming from the branch.
   126  	canFF, err := head.CanFastForwardTo(ctx, cm)
   127  	if err != nil {
   128  		return nil, err
   129  	}
   130  
   131  	if canFF {
   132  		if apr.Contains(cli.NoFFParam) {
   133  			err = executeNoFFMerge(ctx, sess, apr, dbName, dbData, head, cm)
   134  		} else {
   135  			err = executeFFMerge(ctx, apr.Contains(cli.SquashParam), dbName, dbData, cm)
   136  		}
   137  
   138  		if err != nil {
   139  			return nil, err
   140  		}
   141  		return cmh.String(), err
   142  	}
   143  
   144  	err = executeMerge(ctx, apr.Contains(cli.SquashParam), head, cm, dbName, dbData)
   145  	if err != nil {
   146  		return nil, err
   147  	}
   148  
   149  	returnMsg := fmt.Sprintf("Updating %s..%s", cmh.String(), hh.String())
   150  
   151  	return returnMsg, nil
   152  }
   153  
   154  func abortMerge(ctx *sql.Context, dbData env.DbData) error {
   155  	err := actions.CheckoutAllTables(ctx, dbData)
   156  
   157  	if err != nil {
   158  		return err
   159  	}
   160  
   161  	err = dbData.Rsw.AbortMerge()
   162  	if err != nil {
   163  		return err
   164  	}
   165  
   166  	hh, err := dbData.Rsr.CWBHeadHash(ctx)
   167  	if err != nil {
   168  		return err
   169  	}
   170  
   171  	return setHeadAndWorkingSessionRoot(ctx, hh.String())
   172  }
   173  
   174  func executeMerge(ctx *sql.Context, squash bool, head, cm *doltdb.Commit, name string, dbData env.DbData) error {
   175  	mergeRoot, mergeStats, err := merge.MergeCommits(ctx, head, cm)
   176  
   177  	if err != nil {
   178  		switch err {
   179  		case doltdb.ErrUpToDate:
   180  			return errors.New("Already up to date.")
   181  		case merge.ErrFastForward:
   182  			panic("fast forward merge")
   183  		default:
   184  			return errors.New("Bad merge")
   185  		}
   186  	}
   187  
   188  	return mergeRootToWorking(ctx, squash, name, dbData, mergeRoot, cm, mergeStats)
   189  }
   190  
   191  func executeFFMerge(ctx *sql.Context, squash bool, dbName string, dbData env.DbData, cm2 *doltdb.Commit) error {
   192  	rv, err := cm2.GetRootValue()
   193  
   194  	if err != nil {
   195  		return errors.New("Failed to return root value.")
   196  	}
   197  
   198  	stagedHash, err := dbData.Ddb.WriteRootValue(ctx, rv)
   199  
   200  	if err != nil {
   201  		return err
   202  	}
   203  
   204  	workingHash := stagedHash
   205  	if !squash {
   206  		err = dbData.Ddb.FastForward(ctx, dbData.Rsr.CWBHeadRef(), cm2)
   207  
   208  		if err != nil {
   209  			return err
   210  		}
   211  	}
   212  
   213  	err = dbData.Rsw.SetWorkingHash(ctx, workingHash)
   214  	if err != nil {
   215  		return err
   216  	}
   217  
   218  	err = dbData.Rsw.SetStagedHash(ctx, stagedHash)
   219  	if err != nil {
   220  		return err
   221  	}
   222  
   223  	hh, err := dbData.Rsr.CWBHeadHash(ctx)
   224  	if err != nil {
   225  		return err
   226  	}
   227  
   228  	if squash {
   229  		return ctx.SetSessionVariable(ctx, sqle.WorkingKey(dbName), workingHash.String())
   230  	} else {
   231  		return setHeadAndWorkingSessionRoot(ctx, hh.String())
   232  	}
   233  }
   234  
   235  func executeNoFFMerge(
   236  	ctx *sql.Context,
   237  	dSess *sqle.DoltSession,
   238  	apr *argparser.ArgParseResults,
   239  	dbName string,
   240  	dbData env.DbData,
   241  	pr, cm2 *doltdb.Commit,
   242  ) error {
   243  	mergedRoot, err := cm2.GetRootValue()
   244  	if err != nil {
   245  		return errors.New("Failed to return root value.")
   246  	}
   247  
   248  	err = mergeRootToWorking(ctx, false, dbName, dbData, mergedRoot, cm2, map[string]*merge.MergeStats{})
   249  	if err != nil {
   250  		return err
   251  	}
   252  
   253  	msg, msgOk := apr.GetValue(cli.CommitMessageArg)
   254  	if !msgOk {
   255  		hh, err := pr.HashOf()
   256  		if err != nil {
   257  			return err
   258  		}
   259  
   260  		cmh, err := cm2.HashOf()
   261  		if err != nil {
   262  			return err
   263  		}
   264  
   265  		msg = fmt.Sprintf("SQL Generated commit merging %s into %s", hh.String(), cmh.String())
   266  	}
   267  
   268  	var name, email string
   269  	if authorStr, ok := apr.GetValue(cli.AuthorParam); ok {
   270  		name, email, err = cli.ParseAuthor(authorStr)
   271  		if err != nil {
   272  			return err
   273  		}
   274  	} else {
   275  		name = dSess.Username
   276  		email = dSess.Email
   277  	}
   278  
   279  	// Specify the time if the date parameter is not.
   280  	t := ctx.QueryTime()
   281  	if commitTimeStr, ok := apr.GetValue(cli.DateParam); ok {
   282  		var err error
   283  		t, err = cli.ParseDate(commitTimeStr)
   284  		if err != nil {
   285  			return err
   286  		}
   287  	}
   288  
   289  	h, err := actions.CommitStaged(ctx, dbData, actions.CommitStagedProps{
   290  		Message:          msg,
   291  		Date:             t,
   292  		AllowEmpty:       apr.Contains(cli.AllowEmptyFlag),
   293  		CheckForeignKeys: !apr.Contains(cli.ForceFlag),
   294  		Name:             name,
   295  		Email:            email,
   296  	})
   297  
   298  	if err != nil {
   299  		return err
   300  	}
   301  
   302  	return setHeadAndWorkingSessionRoot(ctx, h)
   303  }
   304  
   305  func mergeRootToWorking(
   306  	ctx *sql.Context,
   307  	squash bool,
   308  	dbName string,
   309  	dbData env.DbData,
   310  	mergedRoot *doltdb.RootValue,
   311  	cm2 *doltdb.Commit,
   312  	mergeStats map[string]*merge.MergeStats,
   313  ) error {
   314  	h2, err := cm2.HashOf()
   315  	if err != nil {
   316  		return err
   317  	}
   318  
   319  	workingRoot := mergedRoot
   320  	if !squash {
   321  		err = dbData.Rsw.StartMerge(h2.String())
   322  
   323  		if err != nil {
   324  			return err
   325  		}
   326  	}
   327  
   328  	workingHash, err := env.UpdateWorkingRoot(ctx, dbData.Ddb, dbData.Rsw, workingRoot)
   329  	if err != nil {
   330  		return err
   331  	}
   332  
   333  	hasConflicts := checkForConflicts(mergeStats)
   334  
   335  	if hasConflicts {
   336  		// If there are conflicts write them to the working root anyway too allow for merge resolution via the dolt_conflicts
   337  		// table.
   338  		err := ctx.SetSessionVariable(ctx, sqle.WorkingKey(dbName), workingHash.String())
   339  		if err != nil {
   340  			return err
   341  		}
   342  
   343  		return doltdb.ErrUnresolvedConflicts
   344  	}
   345  
   346  	_, err = env.UpdateStagedRoot(ctx, dbData.Ddb, dbData.Rsw, workingRoot)
   347  	if err != nil {
   348  		return err
   349  	}
   350  
   351  	return ctx.SetSessionVariable(ctx, sqle.WorkingKey(dbName), workingHash.String())
   352  }
   353  
   354  func checkForConflicts(tblToStats map[string]*merge.MergeStats) bool {
   355  	for _, stats := range tblToStats {
   356  		if stats.Operation == merge.TableModified && stats.Conflicts > 0 {
   357  			return true
   358  		}
   359  	}
   360  
   361  	return false
   362  }
   363  
   364  func (d DoltMergeFunc) String() string {
   365  	childrenStrings := make([]string, len(d.Children()))
   366  
   367  	for i, child := range d.Children() {
   368  		childrenStrings[i] = child.String()
   369  	}
   370  
   371  	return fmt.Sprintf("DOLT_MERGE(%s)", strings.Join(childrenStrings, ","))
   372  }
   373  
   374  func (d DoltMergeFunc) Type() sql.Type {
   375  	return sql.Text
   376  }
   377  
   378  func (d DoltMergeFunc) WithChildren(ctx *sql.Context, children ...sql.Expression) (sql.Expression, error) {
   379  	return NewDoltMergeFunc(ctx, children...)
   380  }
   381  
   382  func NewDoltMergeFunc(ctx *sql.Context, args ...sql.Expression) (sql.Expression, error) {
   383  	return &DoltMergeFunc{expression.NaryExpression{ChildExpressions: args}}, nil
   384  }