github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/sqle/dprocedures/dolt_checkout.go (about)

     1  // Copyright 2022 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package dprocedures
    16  
    17  import (
    18  	"errors"
    19  	"fmt"
    20  
    21  	"github.com/dolthub/go-mysql-server/sql"
    22  	"github.com/dolthub/go-mysql-server/sql/types"
    23  
    24  	"github.com/dolthub/dolt/go/cmd/dolt/cli"
    25  	"github.com/dolthub/dolt/go/cmd/dolt/errhand"
    26  	"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
    27  	"github.com/dolthub/dolt/go/libraries/doltcore/env"
    28  	"github.com/dolthub/dolt/go/libraries/doltcore/env/actions"
    29  	"github.com/dolthub/dolt/go/libraries/doltcore/ref"
    30  	"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
    31  	"github.com/dolthub/dolt/go/libraries/utils/argparser"
    32  	"github.com/dolthub/dolt/go/store/hash"
    33  )
    34  
    35  var ErrEmptyBranchName = errors.New("error: cannot checkout empty string")
    36  
    37  var doltCheckoutSchema = []*sql.Column{
    38  	{
    39  		Name:     "status",
    40  		Type:     types.Int64,
    41  		Nullable: false,
    42  	},
    43  	{
    44  		Name:     "message",
    45  		Type:     types.LongText,
    46  		Nullable: true,
    47  	},
    48  }
    49  
    50  // doltCheckout is the stored procedure version for the CLI command `dolt checkout`.
    51  func doltCheckout(ctx *sql.Context, args ...string) (sql.RowIter, error) {
    52  	res, message, err := doDoltCheckout(ctx, args)
    53  	if err != nil {
    54  		return nil, err
    55  	}
    56  	return rowToIter(int64(res), message), nil
    57  }
    58  
    59  func doDoltCheckout(ctx *sql.Context, args []string) (statusCode int, successMessage string, err error) {
    60  	currentDbName := ctx.GetCurrentDatabase()
    61  	if len(currentDbName) == 0 {
    62  		return 1, "", fmt.Errorf("Empty database name.")
    63  	}
    64  
    65  	argParser := cli.CreateCheckoutArgParser()
    66  	// The --move flag is used internally by the `dolt checkout` CLI command. It is not intended for external use.
    67  	// It mimics the behavior of the `dolt checkout` command line, moving the working set into the new branch.
    68  	argParser.SupportsFlag(cli.MoveFlag, "m", "")
    69  	apr, err := argParser.Parse(args)
    70  	if err != nil {
    71  		return 1, "", err
    72  	}
    73  
    74  	newBranch, _, err := parseBranchArgs(apr)
    75  	if err != nil {
    76  		return 1, "", err
    77  	}
    78  
    79  	branchOrTrack := newBranch != "" || apr.Contains(cli.TrackFlag)
    80  	if apr.Contains(cli.TrackFlag) && apr.NArg() > 0 {
    81  		return 1, "", errors.New("Improper usage. Too many arguments provided.")
    82  	}
    83  	if (branchOrTrack && apr.NArg() > 1) || (!branchOrTrack && apr.NArg() == 0) {
    84  		return 1, "", errors.New("Improper usage.")
    85  	}
    86  
    87  	dSess := dsess.DSessFromSess(ctx.Session)
    88  	dbData, ok := dSess.GetDbData(ctx, currentDbName)
    89  	if !ok {
    90  		return 1, "", fmt.Errorf("Could not load database %s", currentDbName)
    91  	}
    92  
    93  	// Prevent the -b option from being used to create new branches on read-only databases
    94  	readOnlyDatabase, err := isReadOnlyDatabase(ctx, currentDbName)
    95  	if err != nil {
    96  		return 1, "", err
    97  	}
    98  	if newBranch != "" && readOnlyDatabase {
    99  		return 1, "", fmt.Errorf("unable to create new branch in a read-only database")
   100  	}
   101  
   102  	updateHead := apr.Contains(cli.MoveFlag)
   103  
   104  	var rsc doltdb.ReplicationStatusController
   105  
   106  	// Checking out new branch.
   107  	if branchOrTrack {
   108  		newBranch, upstream, err := checkoutNewBranch(ctx, currentDbName, dbData, apr, &rsc, updateHead)
   109  		if err != nil {
   110  			return 1, "", err
   111  		} else {
   112  			return 0, generateSuccessMessage(newBranch, upstream), nil
   113  		}
   114  	}
   115  
   116  	branchName := apr.Arg(0)
   117  	if len(branchName) == 0 {
   118  		return 1, "", ErrEmptyBranchName
   119  	}
   120  
   121  	isModification, err := willModifyDb(dSess, dbData, currentDbName, branchName, updateHead)
   122  	if err != nil {
   123  		return 1, "", err
   124  	}
   125  	if !isModification {
   126  		return 0, fmt.Sprintf("Already on branch '%s'", branchName), nil
   127  	}
   128  
   129  	// Check if user wants to checkout branch.
   130  	if isBranch, err := actions.IsBranch(ctx, dbData.Ddb, branchName); err != nil {
   131  		return 1, "", err
   132  	} else if isBranch {
   133  		err = checkoutExistingBranch(ctx, currentDbName, branchName, apr)
   134  		if errors.Is(err, doltdb.ErrWorkingSetNotFound) {
   135  			// If there is a branch but there is no working set,
   136  			// somehow the local branch ref was created without a
   137  			// working set. This happened with old versions of dolt
   138  			// when running as a read replica, for example. Try to
   139  			// create a working set pointing at the existing branch
   140  			// HEAD and check out the branch again.
   141  			//
   142  			// TODO: This is all quite racey, but so is the
   143  			// handling in DoltDB, etc.
   144  			err = createWorkingSetForLocalBranch(ctx, dbData.Ddb, branchName)
   145  			if err != nil {
   146  				return 1, "", err
   147  			}
   148  
   149  			// Since we've created new refs since the transaction began, we need to commit this transaction and
   150  			// start a new one to avoid not found errors after this
   151  			// TODO: this is much worse than other places we do this, because it's two layers of implicit behavior
   152  			sess := dsess.DSessFromSess(ctx.Session)
   153  			err = commitTransaction(ctx, sess, &rsc)
   154  			if err != nil {
   155  				return 1, "", err
   156  			}
   157  
   158  			err = checkoutExistingBranch(ctx, currentDbName, branchName, apr)
   159  		}
   160  		if err != nil {
   161  			return 1, "", err
   162  		}
   163  		return 0, generateSuccessMessage(branchName, ""), nil
   164  	}
   165  
   166  	roots, ok := dSess.GetRoots(ctx, currentDbName)
   167  	if !ok {
   168  		return 1, "", fmt.Errorf("Could not load database %s", currentDbName)
   169  	}
   170  
   171  	// Check if the user executed `dolt checkout .`
   172  	if apr.NArg() == 1 && apr.Arg(0) == "." {
   173  		headRef, err := dbData.Rsr.CWBHeadRef()
   174  		if err != nil {
   175  			return 1, "", err
   176  		}
   177  
   178  		ws, err := dSess.WorkingSet(ctx, currentDbName)
   179  		if err != nil {
   180  			return 1, "", err
   181  		}
   182  		doltDb, hasDb := dSess.GetDoltDB(ctx, currentDbName)
   183  		if !hasDb {
   184  			return 1, "", errors.New("Unable to load database")
   185  		}
   186  		err = actions.ResetHard(ctx, dbData, doltDb, dSess.Username(), dSess.Email(), "", roots, headRef, ws)
   187  		if err != nil {
   188  			return 1, "", err
   189  		}
   190  		return 0, "", err
   191  	}
   192  
   193  	err = checkoutTables(ctx, roots, currentDbName, apr.Args)
   194  	if err != nil && apr.NArg() == 1 {
   195  		upstream, err := checkoutRemoteBranch(ctx, dSess, currentDbName, dbData, branchName, apr, &rsc)
   196  		if err != nil {
   197  			return 1, "", err
   198  		}
   199  		successMessage = generateSuccessMessage(branchName, upstream)
   200  	}
   201  
   202  	dsess.WaitForReplicationController(ctx, rsc)
   203  
   204  	return 0, successMessage, nil
   205  }
   206  
   207  // parseBranchArgs returns the name of the new branch and whether or not it should be created forcibly. This asserts
   208  // that the provided branch name may not be empty, so an empty string is returned where no -b or -B flag is provided.
   209  func parseBranchArgs(apr *argparser.ArgParseResults) (newBranch string, createBranchForcibly bool, err error) {
   210  	if apr.Contains(cli.CheckoutCreateBranch) && apr.Contains(cli.CreateResetBranch) {
   211  		return "", false, errors.New("Improper usage. Cannot use both -b and -B.")
   212  	}
   213  
   214  	if newBranch, ok := apr.GetValue(cli.CheckoutCreateBranch); ok {
   215  		if len(newBranch) == 0 {
   216  			return "", false, ErrEmptyBranchName
   217  		}
   218  		return newBranch, false, nil
   219  	}
   220  
   221  	if newBranch, ok := apr.GetValue(cli.CreateResetBranch); ok {
   222  		if len(newBranch) == 0 {
   223  			return "", false, ErrEmptyBranchName
   224  		}
   225  		return newBranch, true, nil
   226  	}
   227  
   228  	return "", false, nil
   229  }
   230  
   231  // isReadOnlyDatabase returns true if the named database is a read-only database. An error is returned
   232  // if any issues are encountered while looking up the named database.
   233  func isReadOnlyDatabase(ctx *sql.Context, dbName string) (bool, error) {
   234  	doltSession := dsess.DSessFromSess(ctx.Session)
   235  	db, err := doltSession.Provider().Database(ctx, dbName)
   236  	if err != nil {
   237  		return false, err
   238  	}
   239  
   240  	rodb, ok := db.(sql.ReadOnlyDatabase)
   241  	return ok && rodb.IsReadOnly(), nil
   242  }
   243  
   244  // createWorkingSetForLocalBranch will make a new working set for a local
   245  // branch ref if one does not already exist. Can be used to fix up local branch
   246  // state when branches have been created without working sets in the past.
   247  //
   248  // This makes it so that dolt_checkout can checkout workingset-less branches,
   249  // the same as `dolt checkout` at the CLI. The semantics of exactly what
   250  // working set gets created in the new case are different, since the CLI takes
   251  // the working set with it.
   252  //
   253  // TODO: This is cribbed heavily from doltdb.*DoltDB.NewBranchAtCommit.
   254  func createWorkingSetForLocalBranch(ctx *sql.Context, ddb *doltdb.DoltDB, branchName string) error {
   255  	branchRef := ref.NewBranchRef(branchName)
   256  	commit, err := ddb.ResolveCommitRef(ctx, branchRef)
   257  	if err != nil {
   258  		return err
   259  	}
   260  
   261  	commitRoot, err := commit.GetRootValue(ctx)
   262  	if err != nil {
   263  		return err
   264  	}
   265  
   266  	wsRef, err := ref.WorkingSetRefForHead(branchRef)
   267  	if err != nil {
   268  		return err
   269  	}
   270  
   271  	_, err = ddb.ResolveWorkingSet(ctx, wsRef)
   272  	if err == nil {
   273  		// This already exists. Return...
   274  		return nil
   275  	}
   276  	if !errors.Is(err, doltdb.ErrWorkingSetNotFound) {
   277  		return err
   278  	}
   279  
   280  	ws := doltdb.EmptyWorkingSet(wsRef).WithWorkingRoot(commitRoot).WithStagedRoot(commitRoot)
   281  	return ddb.UpdateWorkingSet(ctx, wsRef, ws, hash.Hash{} /* current hash... */, doltdb.TodoWorkingSetMeta(), nil)
   282  }
   283  
   284  // checkoutRemoteBranch checks out a remote branch creating a new local branch with the same name as the remote branch
   285  // and set its upstream. The upstream persists out of sql session. Returns the name of the upstream remote and branch.
   286  func checkoutRemoteBranch(ctx *sql.Context, dSess *dsess.DoltSession, dbName string, dbData env.DbData, branchName string, apr *argparser.ArgParseResults, rsc *doltdb.ReplicationStatusController) (upstream string, err error) {
   287  	remoteRefs, err := actions.GetRemoteBranchRef(ctx, dbData.Ddb, branchName)
   288  	if err != nil {
   289  		return "", errors.New("fatal: unable to read from data repository")
   290  	}
   291  
   292  	if len(remoteRefs) == 0 {
   293  		if doltdb.IsValidCommitHash(branchName) && apr.Contains(cli.MoveFlag) {
   294  
   295  			// User tried to enter a detached head state, which we don't support.
   296  			// Inform and suggest that they check-out a new branch at this commit instead.
   297  
   298  			return "", fmt.Errorf(`dolt does not support a detached head state. To create a branch at this commit instead, run:
   299  
   300  	dolt checkout %s -b {new_branch_name}
   301  `, branchName)
   302  		}
   303  		return "", fmt.Errorf("error: could not find %s", branchName)
   304  	} else if len(remoteRefs) == 1 {
   305  		remoteRef := remoteRefs[0]
   306  		err = actions.CreateBranchWithStartPt(ctx, dbData, branchName, remoteRef.String(), false, rsc)
   307  		if err != nil {
   308  			return "", err
   309  		}
   310  
   311  		// We need to commit the transaction here or else the branch we just created isn't visible to the current transaction,
   312  		// and we are about to switch to it. So set the new branch head for the new transaction, then commit this one
   313  		sess := dsess.DSessFromSess(ctx.Session)
   314  		err = commitTransaction(ctx, sess, rsc)
   315  		if err != nil {
   316  			return "", err
   317  		}
   318  
   319  		err = checkoutExistingBranch(ctx, dbName, branchName, apr)
   320  		if err != nil {
   321  			return "", err
   322  		}
   323  
   324  		// After checking out a new branch, we need to reload the database.
   325  		dbData, ok := dSess.GetDbData(ctx, dbName)
   326  		if !ok {
   327  			return "", fmt.Errorf("Could not reload database %s", dbName)
   328  		}
   329  
   330  		refSpec, err := ref.ParseRefSpecForRemote(remoteRef.GetRemote(), remoteRef.GetBranch())
   331  		if err != nil {
   332  			return "", errhand.BuildDError(fmt.Errorf("%w: '%s'", err, remoteRef.GetRemote()).Error()).Build()
   333  		}
   334  
   335  		headRef, err := dbData.Rsr.CWBHeadRef()
   336  		if err != nil {
   337  			return "", err
   338  		}
   339  
   340  		err = env.SetRemoteUpstreamForRefSpec(dbData.Rsw, refSpec, remoteRef.GetRemote(), headRef)
   341  		if err != nil {
   342  			return "", err
   343  		}
   344  
   345  		return remoteRef.GetPath(), nil
   346  	} else {
   347  		return "", fmt.Errorf("'%s' matched multiple (%v) remote tracking branches", branchName, len(remoteRefs))
   348  	}
   349  }
   350  
   351  // checkoutNewBranch creates a new branch and makes it the active branch for the session.
   352  // If isMove is true, this function also moves the working set from the current branch into the new branch.
   353  // Returns the name of the new branch and the remote upstream branch (empty string if not applicable.)
   354  func checkoutNewBranch(ctx *sql.Context, dbName string, dbData env.DbData, apr *argparser.ArgParseResults, rsc *doltdb.ReplicationStatusController, isMove bool) (newBranchName string, remoteAndBranch string, err error) {
   355  	var remoteName, remoteBranchName string
   356  	var startPt = "head"
   357  	var refSpec ref.RefSpec
   358  
   359  	if apr.NArg() == 1 {
   360  		startPt = apr.Arg(0)
   361  	}
   362  
   363  	trackVal, setTrackUpstream := apr.GetValue(cli.TrackFlag)
   364  	if setTrackUpstream {
   365  		if trackVal == "inherit" {
   366  			return "", "", fmt.Errorf("--track='inherit' is not supported yet")
   367  		} else if trackVal != "direct" {
   368  			startPt = trackVal
   369  		}
   370  		remoteName, remoteBranchName = actions.ParseRemoteBranchName(startPt)
   371  		refSpec, err = ref.ParseRefSpecForRemote(remoteName, remoteBranchName)
   372  		if err != nil {
   373  			return "", "", err
   374  		}
   375  		newBranchName = remoteBranchName
   376  	}
   377  
   378  	// A little wonky behavior here. parseBranchArgs is actually called twice because in this procedure we pass around
   379  	// the parse results, but we also needed to parse the -b and -B flags in the main procedure. It ended up being
   380  	// a little cleaner to just call it again here than to pass the results around.
   381  	var createBranchForcibly bool
   382  	var optionBBranch string
   383  	optionBBranch, createBranchForcibly, err = parseBranchArgs(apr)
   384  	if err != nil {
   385  		return "", "", err
   386  	}
   387  	if optionBBranch != "" {
   388  		newBranchName = optionBBranch
   389  	}
   390  
   391  	err = actions.CreateBranchWithStartPt(ctx, dbData, newBranchName, startPt, createBranchForcibly, rsc)
   392  	if err != nil {
   393  		return "", "", err
   394  	}
   395  
   396  	if setTrackUpstream {
   397  		err = env.SetRemoteUpstreamForRefSpec(dbData.Rsw, refSpec, remoteName, ref.NewBranchRef(newBranchName))
   398  		if err != nil {
   399  			return "", "", err
   400  		}
   401  	} else if autoSetupMerge, err := loadConfig(ctx).GetString("branch.autosetupmerge"); err != nil || autoSetupMerge != "false" {
   402  		remoteName, remoteBranchName = actions.ParseRemoteBranchName(startPt)
   403  		refSpec, err = ref.ParseRefSpecForRemote(remoteName, remoteBranchName)
   404  		if err == nil {
   405  			err = env.SetRemoteUpstreamForRefSpec(dbData.Rsw, refSpec, remoteName, ref.NewBranchRef(newBranchName))
   406  			if err != nil {
   407  				return "", "", err
   408  			}
   409  		}
   410  	}
   411  
   412  	// We need to commit the transaction here or else the branch we just created isn't visible to the current transaction,
   413  	// and we are about to switch to it. So set the new branch head for the new transaction, then commit this one
   414  	sess := dsess.DSessFromSess(ctx.Session)
   415  	err = commitTransaction(ctx, sess, rsc)
   416  	if err != nil {
   417  		return "", "", err
   418  	}
   419  
   420  	if remoteName != "" {
   421  		remoteAndBranch = fmt.Sprintf("%s/%s", remoteName, remoteBranchName)
   422  	}
   423  
   424  	if isMove {
   425  		return newBranchName, remoteAndBranch, doGlobalCheckout(ctx, newBranchName, apr.Contains(cli.ForceFlag), true)
   426  	} else {
   427  
   428  		wsRef, err := ref.WorkingSetRefForHead(ref.NewBranchRef(newBranchName))
   429  		if err != nil {
   430  			return "", "", err
   431  		}
   432  
   433  		err = sess.SwitchWorkingSet(ctx, dbName, wsRef)
   434  		if err != nil {
   435  			return "", "", err
   436  		}
   437  	}
   438  
   439  	return newBranchName, remoteAndBranch, nil
   440  }
   441  
   442  // checkoutExistingBranch updates the active branch reference to point to an already existing branch.
   443  func checkoutExistingBranch(ctx *sql.Context, dbName string, branchName string, apr *argparser.ArgParseResults) error {
   444  	wsRef, err := ref.WorkingSetRefForHead(ref.NewBranchRef(branchName))
   445  	if err != nil {
   446  		return err
   447  	}
   448  
   449  	if ctx.GetCurrentDatabase() != dbName {
   450  		ctx.SetCurrentDatabase(dbName)
   451  	}
   452  
   453  	dSess := dsess.DSessFromSess(ctx.Session)
   454  
   455  	if apr.Contains(cli.MoveFlag) {
   456  		return doGlobalCheckout(ctx, branchName, apr.Contains(cli.ForceFlag), false)
   457  	} else {
   458  		err = dSess.SwitchWorkingSet(ctx, dbName, wsRef)
   459  		if err != nil {
   460  			return err
   461  		}
   462  	}
   463  
   464  	return nil
   465  }
   466  
   467  // doGlobalCheckout implements the behavior of the `dolt checkout` command line, moving the working set into
   468  // the new branch and persisting the checked-out branch into future sessions
   469  func doGlobalCheckout(ctx *sql.Context, branchName string, isForce bool, isNewBranch bool) error {
   470  	err := MoveWorkingSetToBranch(ctx, branchName, isForce, isNewBranch)
   471  	if err != nil && err != doltdb.ErrAlreadyOnBranch {
   472  		return err
   473  	}
   474  
   475  	return nil
   476  }
   477  
   478  func checkoutTables(ctx *sql.Context, roots doltdb.Roots, name string, tables []string) error {
   479  	roots, err := actions.MoveTablesFromHeadToWorking(ctx, roots, tables)
   480  
   481  	if err != nil {
   482  		if doltdb.IsRootValUnreachable(err) {
   483  			rt := doltdb.GetUnreachableRootType(err)
   484  			return fmt.Errorf("error: unable to read the %s", rt.String())
   485  		} else if actions.IsTblNotExist(err) {
   486  			return fmt.Errorf("error: given tables do not exist")
   487  		} else {
   488  			return fmt.Errorf("fatal: Unexpected error checking out tables")
   489  		}
   490  	}
   491  
   492  	dSess := dsess.DSessFromSess(ctx.Session)
   493  	return dSess.SetRoots(ctx, name, roots)
   494  }