github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/sqle/dprocedures/dolt_branch.go (about)

     1  // Copyright 2022 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package dprocedures
    16  
    17  import (
    18  	"errors"
    19  	"fmt"
    20  	"strings"
    21  
    22  	"github.com/dolthub/go-mysql-server/sql"
    23  
    24  	"github.com/dolthub/dolt/go/cmd/dolt/cli"
    25  	"github.com/dolthub/dolt/go/libraries/doltcore/branch_control"
    26  	"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
    27  	"github.com/dolthub/dolt/go/libraries/doltcore/env"
    28  	"github.com/dolthub/dolt/go/libraries/doltcore/env/actions"
    29  	"github.com/dolthub/dolt/go/libraries/doltcore/ref"
    30  	"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
    31  	"github.com/dolthub/dolt/go/libraries/doltcore/sqlserver"
    32  	"github.com/dolthub/dolt/go/libraries/utils/argparser"
    33  	"github.com/dolthub/dolt/go/libraries/utils/filesys"
    34  )
    35  
    36  var (
    37  	EmptyBranchNameErr = errors.New("error: cannot branch empty string")
    38  	InvalidArgErr      = errors.New("error: invalid usage")
    39  )
    40  
    41  // doltBranch is the stored procedure version for the CLI command `dolt branch`.
    42  func doltBranch(ctx *sql.Context, args ...string) (sql.RowIter, error) {
    43  	res, err := doDoltBranch(ctx, args)
    44  	if err != nil {
    45  		return nil, err
    46  	}
    47  	return rowToIter(int64(res)), nil
    48  }
    49  
    50  func doDoltBranch(ctx *sql.Context, args []string) (int, error) {
    51  	dbName := ctx.GetCurrentDatabase()
    52  
    53  	if len(dbName) == 0 {
    54  		return 1, fmt.Errorf("Empty database name.")
    55  	}
    56  
    57  	// CreateBranchArgParser has the common flags for the command line and the stored procedure.
    58  	// The stored procedure doesn't support all actions, so we have a shorter description for -r.
    59  	ap := cli.CreateBranchArgParser()
    60  	ap.SupportsFlag(cli.RemoteParam, "r", "Delete a remote tracking branch.")
    61  	apr, err := ap.Parse(args)
    62  	if err != nil {
    63  		return 1, err
    64  	}
    65  
    66  	dSess := dsess.DSessFromSess(ctx.Session)
    67  	dbData, ok := dSess.GetDbData(ctx, dbName)
    68  	if !ok {
    69  		return 1, fmt.Errorf("Could not load database %s", dbName)
    70  	}
    71  
    72  	var rsc doltdb.ReplicationStatusController
    73  
    74  	switch {
    75  	case apr.Contains(cli.CopyFlag):
    76  		err = copyBranch(ctx, dbData, apr, &rsc)
    77  	case apr.Contains(cli.MoveFlag):
    78  		err = renameBranch(ctx, dbData, apr, dSess, dbName, &rsc)
    79  	case apr.Contains(cli.DeleteFlag), apr.Contains(cli.DeleteForceFlag):
    80  		err = deleteBranches(ctx, dbData, apr, dSess, dbName, &rsc)
    81  	default:
    82  		err = createNewBranch(ctx, dbData, apr, &rsc)
    83  	}
    84  
    85  	if err != nil {
    86  		return 1, err
    87  	} else {
    88  		return 0, commitTransaction(ctx, dSess, &rsc)
    89  	}
    90  }
    91  
    92  func commitTransaction(ctx *sql.Context, dSess *dsess.DoltSession, rsc *doltdb.ReplicationStatusController) error {
    93  	currentTx := ctx.GetTransaction()
    94  
    95  	err := dSess.CommitTransaction(ctx, currentTx)
    96  	if err != nil {
    97  		return err
    98  	}
    99  	newTx, err := dSess.StartTransaction(ctx, sql.ReadWrite)
   100  	if err != nil {
   101  		return err
   102  	}
   103  	ctx.SetTransaction(newTx)
   104  
   105  	if rsc != nil {
   106  		dsess.WaitForReplicationController(ctx, *rsc)
   107  	}
   108  
   109  	return nil
   110  }
   111  
   112  // renameBranch takes DoltSession and database name to try accessing file system for dolt database.
   113  // If the oldBranch being renamed is the current branch on CLI, then RepoState head will be updated with the newBranch ref.
   114  func renameBranch(ctx *sql.Context, dbData env.DbData, apr *argparser.ArgParseResults, sess *dsess.DoltSession, dbName string, rsc *doltdb.ReplicationStatusController) error {
   115  	if apr.NArg() != 2 {
   116  		return InvalidArgErr
   117  	}
   118  	oldBranchName, newBranchName := apr.Arg(0), apr.Arg(1)
   119  	if oldBranchName == "" || newBranchName == "" {
   120  		return EmptyBranchNameErr
   121  	}
   122  	if err := branch_control.CanDeleteBranch(ctx, oldBranchName); err != nil {
   123  		return err
   124  	}
   125  	if err := branch_control.CanCreateBranch(ctx, newBranchName); err != nil {
   126  		return err
   127  	}
   128  	force := apr.Contains(cli.ForceFlag)
   129  
   130  	if !force {
   131  		err := validateBranchNotActiveInAnySession(ctx, oldBranchName)
   132  		if err != nil {
   133  			return err
   134  		}
   135  		var headOnCLI string
   136  		fs, err := sess.Provider().FileSystemForDatabase(dbName)
   137  		if err == nil {
   138  			if repoState, err := env.LoadRepoState(fs); err == nil {
   139  				headOnCLI = repoState.Head.Ref.GetPath()
   140  			}
   141  		}
   142  		if headOnCLI == oldBranchName && sqlserver.RunningInServerMode() && !shouldAllowDefaultBranchDeletion(ctx) {
   143  			return fmt.Errorf("unable to rename branch '%s', because it is the default branch for "+
   144  				"database '%s'; this can by changed on the command line, by stopping the sql-server, "+
   145  				"running `dolt checkout <another_branch> and restarting the sql-server", oldBranchName, dbName)
   146  		}
   147  
   148  	} else if err := branch_control.CanDeleteBranch(ctx, newBranchName); err != nil {
   149  		// If force is enabled, we can overwrite the destination branch, so we require a permission check here, even if the
   150  		// destination branch doesn't exist. An unauthorized user could simply rerun the command without the force flag.
   151  		return err
   152  	}
   153  
   154  	headRef, err := dbData.Rsr.CWBHeadRef()
   155  	if err != nil {
   156  		return err
   157  	}
   158  	activeSessionBranch := headRef.GetPath()
   159  
   160  	err = actions.RenameBranch(ctx, dbData, oldBranchName, newBranchName, sess.Provider(), force, rsc)
   161  	if err != nil {
   162  		return err
   163  	}
   164  	err = branch_control.AddAdminForContext(ctx, newBranchName)
   165  	if err != nil {
   166  		return err
   167  	}
   168  
   169  	// The current branch on CLI can be deleted as user can be on different branch on SQL and delete it from SQL session.
   170  	// To update current head info on RepoState, we need DoltEnv to load CLI environment.
   171  	if fs, err := sess.Provider().FileSystemForDatabase(dbName); err == nil {
   172  		if repoState, err := env.LoadRepoState(fs); err == nil {
   173  			if repoState.Head.Ref.GetPath() == oldBranchName {
   174  				repoState.Head.Ref = ref.NewBranchRef(newBranchName)
   175  				repoState.Save(fs)
   176  			}
   177  		}
   178  	}
   179  
   180  	err = sess.RenameBranchState(ctx, dbName, oldBranchName, newBranchName)
   181  	if err != nil {
   182  		return err
   183  	}
   184  
   185  	// If the active branch of the SQL session was renamed, switch to the new branch.
   186  	if oldBranchName == activeSessionBranch {
   187  		wsRef, err := ref.WorkingSetRefForHead(ref.NewBranchRef(newBranchName))
   188  		if err != nil {
   189  			return err
   190  		}
   191  
   192  		err = sess.SwitchWorkingSet(ctx, dbName, wsRef)
   193  		if err != nil {
   194  			return err
   195  		}
   196  	}
   197  
   198  	return nil
   199  }
   200  
   201  // deleteBranches takes DoltSession and database name to try accessing file system for dolt database.
   202  // If the database is not session state db and the branch being deleted is the current branch on CLI, it will update
   203  // the RepoState to set head as empty branchRef.
   204  func deleteBranches(ctx *sql.Context, dbData env.DbData, apr *argparser.ArgParseResults, sess *dsess.DoltSession, dbName string, rsc *doltdb.ReplicationStatusController) error {
   205  	if apr.NArg() == 0 {
   206  		return InvalidArgErr
   207  	}
   208  
   209  	currBase, currBranch := dsess.SplitRevisionDbName(ctx.GetCurrentDatabase())
   210  
   211  	// The current branch on CLI can be deleted as user can be on different branch on SQL and delete it from SQL session.
   212  	// To update current head info on RepoState, we need DoltEnv to load CLI environment.
   213  	var headOnCLI string
   214  	fs, err := sess.Provider().FileSystemForDatabase(dbName)
   215  	if err == nil {
   216  		if repoState, err := env.LoadRepoState(fs); err == nil {
   217  			headOnCLI = repoState.Head.Ref.GetPath()
   218  		}
   219  	}
   220  
   221  	// Verify that we can delete all branches before continuing
   222  	for _, branchName := range apr.Args {
   223  		if err = branch_control.CanDeleteBranch(ctx, branchName); err != nil {
   224  			return err
   225  		}
   226  	}
   227  
   228  	dSess := dsess.DSessFromSess(ctx.Session)
   229  	for _, branchName := range apr.Args {
   230  		if len(branchName) == 0 {
   231  			return EmptyBranchNameErr
   232  		}
   233  
   234  		force := apr.Contains(cli.DeleteForceFlag) || apr.Contains(cli.ForceFlag)
   235  		if !force {
   236  			err = validateBranchNotActiveInAnySession(ctx, branchName)
   237  			if err != nil {
   238  				return err
   239  			}
   240  		}
   241  
   242  		// If we deleted the branch this client is connected to, change the current branch to the default
   243  		// TODO: this would be nice to do for every other session (or maybe invalidate sessions on this branch)
   244  		if strings.ToLower(currBranch) == strings.ToLower(branchName) {
   245  			ctx.SetCurrentDatabase(currBase)
   246  		}
   247  
   248  		if headOnCLI == branchName && sqlserver.RunningInServerMode() && !shouldAllowDefaultBranchDeletion(ctx) {
   249  			return fmt.Errorf("unable to delete branch '%s', because it is the default branch for "+
   250  				"database '%s'; this can by changed on the command line, by stopping the sql-server, "+
   251  				"running `dolt checkout <another_branch> and restarting the sql-server", branchName, dbName)
   252  		}
   253  
   254  		remote := apr.Contains(cli.RemoteParam)
   255  
   256  		err = actions.DeleteBranch(ctx, dbData, branchName, actions.DeleteOptions{
   257  			Force:  force,
   258  			Remote: remote,
   259  		}, dSess.Provider(), rsc)
   260  		if err != nil {
   261  			return err
   262  		}
   263  
   264  		// If the session has this branch checked out, we need to change that to the default head
   265  		headRef, err := dSess.CWBHeadRef(ctx, currBase)
   266  		if err != nil {
   267  			return err
   268  		}
   269  
   270  		if headRef == ref.NewBranchRef(branchName) {
   271  			err = dSess.RemoveBranchState(ctx, currBase, branchName)
   272  			if err != nil {
   273  				return err
   274  			}
   275  		}
   276  	}
   277  
   278  	return nil
   279  }
   280  
   281  // shouldAllowDefaultBranchDeletion returns true if the default branch deletion check should be
   282  // bypassed for testing. This should only ever be true for tests that need to invalidate a databases
   283  // default branch to test recovery from a bad state. We determine if the check should be bypassed by
   284  // looking for the presence of an undocumented dolt user var, dolt_allow_default_branch_deletion.
   285  func shouldAllowDefaultBranchDeletion(ctx *sql.Context) bool {
   286  	_, userVar, _ := ctx.Session.GetUserVariable(ctx, "dolt_allow_default_branch_deletion")
   287  	return userVar != nil
   288  }
   289  
   290  // validateBranchNotActiveInAnySessions returns an error if the specified branch is currently
   291  // selected as the active branch for any active server sessions.
   292  func validateBranchNotActiveInAnySession(ctx *sql.Context, branchName string) error {
   293  	currentDbName := ctx.GetCurrentDatabase()
   294  	currentDbName, _ = dsess.SplitRevisionDbName(currentDbName)
   295  	if currentDbName == "" {
   296  		return nil
   297  	}
   298  
   299  	if sqlserver.RunningInServerMode() == false {
   300  		return nil
   301  	}
   302  
   303  	runningServer := sqlserver.GetRunningServer()
   304  	if runningServer == nil {
   305  		return nil
   306  	}
   307  	sessionManager := runningServer.SessionManager()
   308  	branchRef := ref.NewBranchRef(branchName)
   309  
   310  	return sessionManager.Iter(func(session sql.Session) (bool, error) {
   311  		if session.ID() == ctx.Session.ID() {
   312  			return false, nil
   313  		}
   314  
   315  		sess, ok := session.(*dsess.DoltSession)
   316  		if !ok {
   317  			return false, fmt.Errorf("unexpected session type: %T", session)
   318  		}
   319  
   320  		sessionDbName := sess.Session.GetCurrentDatabase()
   321  		baseName, _ := dsess.SplitRevisionDbName(sessionDbName)
   322  		if len(baseName) == 0 || baseName != currentDbName {
   323  			return false, nil
   324  		}
   325  
   326  		activeBranchRef, err := sess.CWBHeadRef(ctx, sessionDbName)
   327  		if err != nil {
   328  			// The above will throw an error if the current DB doesn't have a head ref, in which case we don't need to
   329  			// consider it
   330  			return false, nil
   331  		}
   332  
   333  		if ref.Equals(branchRef, activeBranchRef) {
   334  			return false, fmt.Errorf("unsafe to delete or rename branches in use in other sessions; " +
   335  				"use --force to force the change")
   336  		}
   337  
   338  		return false, nil
   339  	})
   340  }
   341  
   342  // TODO: the config should be available via the context, it's unnecessary to do an env.Load here and this should be removed
   343  func loadConfig(ctx *sql.Context) *env.DoltCliConfig {
   344  	// When executing branch actions from SQL, we don't have access to a DoltEnv like we do from
   345  	// within the CLI. We can fake it here enough to get a DoltCliConfig, but we can't rely on the
   346  	// DoltEnv because tests and production will run with different settings (e.g. in-mem versus file).
   347  	dEnv := env.Load(ctx, env.GetCurrentUserHomeDir, filesys.LocalFS, doltdb.LocalDirDoltDB, "")
   348  	return dEnv.Config
   349  }
   350  
   351  func createNewBranch(ctx *sql.Context, dbData env.DbData, apr *argparser.ArgParseResults, rsc *doltdb.ReplicationStatusController) error {
   352  	if apr.NArg() == 0 || apr.NArg() > 2 {
   353  		return InvalidArgErr
   354  	}
   355  
   356  	var branchName = apr.Arg(0)
   357  	var startPt = "HEAD"
   358  	if len(branchName) == 0 {
   359  		return EmptyBranchNameErr
   360  	}
   361  	if apr.NArg() == 2 {
   362  		startPt = apr.Arg(1)
   363  		if len(startPt) == 0 {
   364  			return InvalidArgErr
   365  		}
   366  	}
   367  
   368  	var remoteName, remoteBranch string
   369  	var refSpec ref.RefSpec
   370  	var err error
   371  	trackVal, setTrackUpstream := apr.GetValue(cli.TrackFlag)
   372  	if setTrackUpstream {
   373  		if trackVal == "inherit" {
   374  			return fmt.Errorf("--track='inherit' is not supported yet")
   375  		} else if trackVal == "direct" && apr.NArg() != 2 {
   376  			return InvalidArgErr
   377  		}
   378  
   379  		if apr.NArg() == 2 {
   380  			// branchName and startPt are already set
   381  			remoteName, remoteBranch = actions.ParseRemoteBranchName(startPt)
   382  			refSpec, err = ref.ParseRefSpecForRemote(remoteName, remoteBranch)
   383  			if err != nil {
   384  				return err
   385  			}
   386  		} else {
   387  			// if track option is defined with no value,
   388  			// the track value can either be starting point name OR branch name
   389  			startPt = trackVal
   390  			remoteName, remoteBranch = actions.ParseRemoteBranchName(startPt)
   391  			refSpec, err = ref.ParseRefSpecForRemote(remoteName, remoteBranch)
   392  			if err != nil {
   393  				branchName = trackVal
   394  				startPt = apr.Arg(0)
   395  				remoteName, remoteBranch = actions.ParseRemoteBranchName(startPt)
   396  				refSpec, err = ref.ParseRefSpecForRemote(remoteName, remoteBranch)
   397  				if err != nil {
   398  					return err
   399  				}
   400  			}
   401  		}
   402  	}
   403  
   404  	err = branch_control.CanCreateBranch(ctx, branchName)
   405  	if err != nil {
   406  		return err
   407  	}
   408  
   409  	err = actions.CreateBranchWithStartPt(ctx, dbData, branchName, startPt, apr.Contains(cli.ForceFlag), rsc)
   410  	if err != nil {
   411  		return err
   412  	}
   413  
   414  	if setTrackUpstream {
   415  		// at this point new branch is created
   416  		err = env.SetRemoteUpstreamForRefSpec(dbData.Rsw, refSpec, remoteName, ref.NewBranchRef(branchName))
   417  		if err != nil {
   418  			return err
   419  		}
   420  	}
   421  
   422  	return nil
   423  }
   424  
   425  func copyBranch(ctx *sql.Context, dbData env.DbData, apr *argparser.ArgParseResults, rsc *doltdb.ReplicationStatusController) error {
   426  	if apr.NArg() != 2 {
   427  		return InvalidArgErr
   428  	}
   429  
   430  	srcBr := apr.Args[0]
   431  	if len(srcBr) == 0 {
   432  		return EmptyBranchNameErr
   433  	}
   434  
   435  	destBr := apr.Args[1]
   436  	if len(destBr) == 0 {
   437  		return EmptyBranchNameErr
   438  	}
   439  
   440  	force := apr.Contains(cli.ForceFlag)
   441  	return copyABranch(ctx, dbData, srcBr, destBr, force, rsc)
   442  }
   443  
   444  func copyABranch(ctx *sql.Context, dbData env.DbData, srcBr string, destBr string, force bool, rsc *doltdb.ReplicationStatusController) error {
   445  	if err := branch_control.CanCreateBranch(ctx, destBr); err != nil {
   446  		return err
   447  	}
   448  	// If force is enabled, we can overwrite the destination branch, so we require a permission check here, even if the
   449  	// destination branch doesn't exist. An unauthorized user could simply rerun the command without the force flag.
   450  	if force {
   451  		if err := branch_control.CanDeleteBranch(ctx, destBr); err != nil {
   452  			return err
   453  		}
   454  	}
   455  	err := actions.CopyBranchOnDB(ctx, dbData.Ddb, srcBr, destBr, force, rsc)
   456  	if err != nil {
   457  		if err == doltdb.ErrBranchNotFound {
   458  			return fmt.Errorf("fatal: A branch named '%s' not found", srcBr)
   459  		} else if err == actions.ErrAlreadyExists {
   460  			return fmt.Errorf("fatal: A branch named '%s' already exists.", destBr)
   461  		} else if err == doltdb.ErrInvBranchName {
   462  			return fmt.Errorf("fatal: '%s' is not a valid branch name.", destBr)
   463  		} else {
   464  			return fmt.Errorf("fatal: Unexpected error copying branch from '%s' to '%s'", srcBr, destBr)
   465  		}
   466  	}
   467  	err = branch_control.AddAdminForContext(ctx, destBr)
   468  	if err != nil {
   469  		return err
   470  	}
   471  
   472  	return nil
   473  }