// Copyright 2022 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package dprocedures

import (
	"errors"
	"fmt"
	"strings"

	"github.com/dolthub/go-mysql-server/sql"

	"github.com/dolthub/dolt/go/cmd/dolt/cli"
	"github.com/dolthub/dolt/go/libraries/doltcore/branch_control"
	"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
	"github.com/dolthub/dolt/go/libraries/doltcore/env"
	"github.com/dolthub/dolt/go/libraries/doltcore/env/actions"
	"github.com/dolthub/dolt/go/libraries/doltcore/ref"
	"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
	"github.com/dolthub/dolt/go/libraries/doltcore/sqlserver"
	"github.com/dolthub/dolt/go/libraries/utils/argparser"
	"github.com/dolthub/dolt/go/libraries/utils/filesys"
)

var (
	// EmptyBranchNameErr is returned by the dolt_branch actions when a branch name argument
	// is the empty string.
	EmptyBranchNameErr = errors.New("error: cannot branch empty string")
	// InvalidArgErr is returned when dolt_branch receives the wrong number of positional
	// arguments for the requested action, or an otherwise malformed argument (for example an
	// empty start point).
	InvalidArgErr = errors.New("error: invalid usage")
)

// doltBranch is the stored procedure version for the CLI command `dolt branch`.
// On success it returns a single row holding the status code 0; failures are returned
// as an error rather than as a non-zero status row.
42 func doltBranch(ctx *sql.Context, args ...string) (sql.RowIter, error) { 43 res, err := doDoltBranch(ctx, args) 44 if err != nil { 45 return nil, err 46 } 47 return rowToIter(int64(res)), nil 48 } 49 50 func doDoltBranch(ctx *sql.Context, args []string) (int, error) { 51 dbName := ctx.GetCurrentDatabase() 52 53 if len(dbName) == 0 { 54 return 1, fmt.Errorf("Empty database name.") 55 } 56 57 // CreateBranchArgParser has the common flags for the command line and the stored procedure. 58 // The stored procedure doesn't support all actions, so we have a shorter description for -r. 59 ap := cli.CreateBranchArgParser() 60 ap.SupportsFlag(cli.RemoteParam, "r", "Delete a remote tracking branch.") 61 apr, err := ap.Parse(args) 62 if err != nil { 63 return 1, err 64 } 65 66 dSess := dsess.DSessFromSess(ctx.Session) 67 dbData, ok := dSess.GetDbData(ctx, dbName) 68 if !ok { 69 return 1, fmt.Errorf("Could not load database %s", dbName) 70 } 71 72 var rsc doltdb.ReplicationStatusController 73 74 switch { 75 case apr.Contains(cli.CopyFlag): 76 err = copyBranch(ctx, dbData, apr, &rsc) 77 case apr.Contains(cli.MoveFlag): 78 err = renameBranch(ctx, dbData, apr, dSess, dbName, &rsc) 79 case apr.Contains(cli.DeleteFlag), apr.Contains(cli.DeleteForceFlag): 80 err = deleteBranches(ctx, dbData, apr, dSess, dbName, &rsc) 81 default: 82 err = createNewBranch(ctx, dbData, apr, &rsc) 83 } 84 85 if err != nil { 86 return 1, err 87 } else { 88 return 0, commitTransaction(ctx, dSess, &rsc) 89 } 90 } 91 92 func commitTransaction(ctx *sql.Context, dSess *dsess.DoltSession, rsc *doltdb.ReplicationStatusController) error { 93 currentTx := ctx.GetTransaction() 94 95 err := dSess.CommitTransaction(ctx, currentTx) 96 if err != nil { 97 return err 98 } 99 newTx, err := dSess.StartTransaction(ctx, sql.ReadWrite) 100 if err != nil { 101 return err 102 } 103 ctx.SetTransaction(newTx) 104 105 if rsc != nil { 106 dsess.WaitForReplicationController(ctx, *rsc) 107 } 108 109 return nil 110 } 111 112 // 
renameBranch takes DoltSession and database name to try accessing file system for dolt database. 113 // If the oldBranch being renamed is the current branch on CLI, then RepoState head will be updated with the newBranch ref. 114 func renameBranch(ctx *sql.Context, dbData env.DbData, apr *argparser.ArgParseResults, sess *dsess.DoltSession, dbName string, rsc *doltdb.ReplicationStatusController) error { 115 if apr.NArg() != 2 { 116 return InvalidArgErr 117 } 118 oldBranchName, newBranchName := apr.Arg(0), apr.Arg(1) 119 if oldBranchName == "" || newBranchName == "" { 120 return EmptyBranchNameErr 121 } 122 if err := branch_control.CanDeleteBranch(ctx, oldBranchName); err != nil { 123 return err 124 } 125 if err := branch_control.CanCreateBranch(ctx, newBranchName); err != nil { 126 return err 127 } 128 force := apr.Contains(cli.ForceFlag) 129 130 if !force { 131 err := validateBranchNotActiveInAnySession(ctx, oldBranchName) 132 if err != nil { 133 return err 134 } 135 var headOnCLI string 136 fs, err := sess.Provider().FileSystemForDatabase(dbName) 137 if err == nil { 138 if repoState, err := env.LoadRepoState(fs); err == nil { 139 headOnCLI = repoState.Head.Ref.GetPath() 140 } 141 } 142 if headOnCLI == oldBranchName && sqlserver.RunningInServerMode() && !shouldAllowDefaultBranchDeletion(ctx) { 143 return fmt.Errorf("unable to rename branch '%s', because it is the default branch for "+ 144 "database '%s'; this can by changed on the command line, by stopping the sql-server, "+ 145 "running `dolt checkout <another_branch> and restarting the sql-server", oldBranchName, dbName) 146 } 147 148 } else if err := branch_control.CanDeleteBranch(ctx, newBranchName); err != nil { 149 // If force is enabled, we can overwrite the destination branch, so we require a permission check here, even if the 150 // destination branch doesn't exist. An unauthorized user could simply rerun the command without the force flag. 
151 return err 152 } 153 154 headRef, err := dbData.Rsr.CWBHeadRef() 155 if err != nil { 156 return err 157 } 158 activeSessionBranch := headRef.GetPath() 159 160 err = actions.RenameBranch(ctx, dbData, oldBranchName, newBranchName, sess.Provider(), force, rsc) 161 if err != nil { 162 return err 163 } 164 err = branch_control.AddAdminForContext(ctx, newBranchName) 165 if err != nil { 166 return err 167 } 168 169 // The current branch on CLI can be deleted as user can be on different branch on SQL and delete it from SQL session. 170 // To update current head info on RepoState, we need DoltEnv to load CLI environment. 171 if fs, err := sess.Provider().FileSystemForDatabase(dbName); err == nil { 172 if repoState, err := env.LoadRepoState(fs); err == nil { 173 if repoState.Head.Ref.GetPath() == oldBranchName { 174 repoState.Head.Ref = ref.NewBranchRef(newBranchName) 175 repoState.Save(fs) 176 } 177 } 178 } 179 180 err = sess.RenameBranchState(ctx, dbName, oldBranchName, newBranchName) 181 if err != nil { 182 return err 183 } 184 185 // If the active branch of the SQL session was renamed, switch to the new branch. 186 if oldBranchName == activeSessionBranch { 187 wsRef, err := ref.WorkingSetRefForHead(ref.NewBranchRef(newBranchName)) 188 if err != nil { 189 return err 190 } 191 192 err = sess.SwitchWorkingSet(ctx, dbName, wsRef) 193 if err != nil { 194 return err 195 } 196 } 197 198 return nil 199 } 200 201 // deleteBranches takes DoltSession and database name to try accessing file system for dolt database. 202 // If the database is not session state db and the branch being deleted is the current branch on CLI, it will update 203 // the RepoState to set head as empty branchRef. 
204 func deleteBranches(ctx *sql.Context, dbData env.DbData, apr *argparser.ArgParseResults, sess *dsess.DoltSession, dbName string, rsc *doltdb.ReplicationStatusController) error { 205 if apr.NArg() == 0 { 206 return InvalidArgErr 207 } 208 209 currBase, currBranch := dsess.SplitRevisionDbName(ctx.GetCurrentDatabase()) 210 211 // The current branch on CLI can be deleted as user can be on different branch on SQL and delete it from SQL session. 212 // To update current head info on RepoState, we need DoltEnv to load CLI environment. 213 var headOnCLI string 214 fs, err := sess.Provider().FileSystemForDatabase(dbName) 215 if err == nil { 216 if repoState, err := env.LoadRepoState(fs); err == nil { 217 headOnCLI = repoState.Head.Ref.GetPath() 218 } 219 } 220 221 // Verify that we can delete all branches before continuing 222 for _, branchName := range apr.Args { 223 if err = branch_control.CanDeleteBranch(ctx, branchName); err != nil { 224 return err 225 } 226 } 227 228 dSess := dsess.DSessFromSess(ctx.Session) 229 for _, branchName := range apr.Args { 230 if len(branchName) == 0 { 231 return EmptyBranchNameErr 232 } 233 234 force := apr.Contains(cli.DeleteForceFlag) || apr.Contains(cli.ForceFlag) 235 if !force { 236 err = validateBranchNotActiveInAnySession(ctx, branchName) 237 if err != nil { 238 return err 239 } 240 } 241 242 // If we deleted the branch this client is connected to, change the current branch to the default 243 // TODO: this would be nice to do for every other session (or maybe invalidate sessions on this branch) 244 if strings.ToLower(currBranch) == strings.ToLower(branchName) { 245 ctx.SetCurrentDatabase(currBase) 246 } 247 248 if headOnCLI == branchName && sqlserver.RunningInServerMode() && !shouldAllowDefaultBranchDeletion(ctx) { 249 return fmt.Errorf("unable to delete branch '%s', because it is the default branch for "+ 250 "database '%s'; this can by changed on the command line, by stopping the sql-server, "+ 251 "running `dolt checkout 
<another_branch> and restarting the sql-server", branchName, dbName) 252 } 253 254 remote := apr.Contains(cli.RemoteParam) 255 256 err = actions.DeleteBranch(ctx, dbData, branchName, actions.DeleteOptions{ 257 Force: force, 258 Remote: remote, 259 }, dSess.Provider(), rsc) 260 if err != nil { 261 return err 262 } 263 264 // If the session has this branch checked out, we need to change that to the default head 265 headRef, err := dSess.CWBHeadRef(ctx, currBase) 266 if err != nil { 267 return err 268 } 269 270 if headRef == ref.NewBranchRef(branchName) { 271 err = dSess.RemoveBranchState(ctx, currBase, branchName) 272 if err != nil { 273 return err 274 } 275 } 276 } 277 278 return nil 279 } 280 281 // shouldAllowDefaultBranchDeletion returns true if the default branch deletion check should be 282 // bypassed for testing. This should only ever be true for tests that need to invalidate a databases 283 // default branch to test recovery from a bad state. We determine if the check should be bypassed by 284 // looking for the presence of an undocumented dolt user var, dolt_allow_default_branch_deletion. 285 func shouldAllowDefaultBranchDeletion(ctx *sql.Context) bool { 286 _, userVar, _ := ctx.Session.GetUserVariable(ctx, "dolt_allow_default_branch_deletion") 287 return userVar != nil 288 } 289 290 // validateBranchNotActiveInAnySessions returns an error if the specified branch is currently 291 // selected as the active branch for any active server sessions. 
292 func validateBranchNotActiveInAnySession(ctx *sql.Context, branchName string) error { 293 currentDbName := ctx.GetCurrentDatabase() 294 currentDbName, _ = dsess.SplitRevisionDbName(currentDbName) 295 if currentDbName == "" { 296 return nil 297 } 298 299 if sqlserver.RunningInServerMode() == false { 300 return nil 301 } 302 303 runningServer := sqlserver.GetRunningServer() 304 if runningServer == nil { 305 return nil 306 } 307 sessionManager := runningServer.SessionManager() 308 branchRef := ref.NewBranchRef(branchName) 309 310 return sessionManager.Iter(func(session sql.Session) (bool, error) { 311 if session.ID() == ctx.Session.ID() { 312 return false, nil 313 } 314 315 sess, ok := session.(*dsess.DoltSession) 316 if !ok { 317 return false, fmt.Errorf("unexpected session type: %T", session) 318 } 319 320 sessionDbName := sess.Session.GetCurrentDatabase() 321 baseName, _ := dsess.SplitRevisionDbName(sessionDbName) 322 if len(baseName) == 0 || baseName != currentDbName { 323 return false, nil 324 } 325 326 activeBranchRef, err := sess.CWBHeadRef(ctx, sessionDbName) 327 if err != nil { 328 // The above will throw an error if the current DB doesn't have a head ref, in which case we don't need to 329 // consider it 330 return false, nil 331 } 332 333 if ref.Equals(branchRef, activeBranchRef) { 334 return false, fmt.Errorf("unsafe to delete or rename branches in use in other sessions; " + 335 "use --force to force the change") 336 } 337 338 return false, nil 339 }) 340 } 341 342 // TODO: the config should be available via the context, it's unnecessary to do an env.Load here and this should be removed 343 func loadConfig(ctx *sql.Context) *env.DoltCliConfig { 344 // When executing branch actions from SQL, we don't have access to a DoltEnv like we do from 345 // within the CLI. We can fake it here enough to get a DoltCliConfig, but we can't rely on the 346 // DoltEnv because tests and production will run with different settings (e.g. in-mem versus file). 
347 dEnv := env.Load(ctx, env.GetCurrentUserHomeDir, filesys.LocalFS, doltdb.LocalDirDoltDB, "") 348 return dEnv.Config 349 } 350 351 func createNewBranch(ctx *sql.Context, dbData env.DbData, apr *argparser.ArgParseResults, rsc *doltdb.ReplicationStatusController) error { 352 if apr.NArg() == 0 || apr.NArg() > 2 { 353 return InvalidArgErr 354 } 355 356 var branchName = apr.Arg(0) 357 var startPt = "HEAD" 358 if len(branchName) == 0 { 359 return EmptyBranchNameErr 360 } 361 if apr.NArg() == 2 { 362 startPt = apr.Arg(1) 363 if len(startPt) == 0 { 364 return InvalidArgErr 365 } 366 } 367 368 var remoteName, remoteBranch string 369 var refSpec ref.RefSpec 370 var err error 371 trackVal, setTrackUpstream := apr.GetValue(cli.TrackFlag) 372 if setTrackUpstream { 373 if trackVal == "inherit" { 374 return fmt.Errorf("--track='inherit' is not supported yet") 375 } else if trackVal == "direct" && apr.NArg() != 2 { 376 return InvalidArgErr 377 } 378 379 if apr.NArg() == 2 { 380 // branchName and startPt are already set 381 remoteName, remoteBranch = actions.ParseRemoteBranchName(startPt) 382 refSpec, err = ref.ParseRefSpecForRemote(remoteName, remoteBranch) 383 if err != nil { 384 return err 385 } 386 } else { 387 // if track option is defined with no value, 388 // the track value can either be starting point name OR branch name 389 startPt = trackVal 390 remoteName, remoteBranch = actions.ParseRemoteBranchName(startPt) 391 refSpec, err = ref.ParseRefSpecForRemote(remoteName, remoteBranch) 392 if err != nil { 393 branchName = trackVal 394 startPt = apr.Arg(0) 395 remoteName, remoteBranch = actions.ParseRemoteBranchName(startPt) 396 refSpec, err = ref.ParseRefSpecForRemote(remoteName, remoteBranch) 397 if err != nil { 398 return err 399 } 400 } 401 } 402 } 403 404 err = branch_control.CanCreateBranch(ctx, branchName) 405 if err != nil { 406 return err 407 } 408 409 err = actions.CreateBranchWithStartPt(ctx, dbData, branchName, startPt, apr.Contains(cli.ForceFlag), rsc) 410 if 
err != nil { 411 return err 412 } 413 414 if setTrackUpstream { 415 // at this point new branch is created 416 err = env.SetRemoteUpstreamForRefSpec(dbData.Rsw, refSpec, remoteName, ref.NewBranchRef(branchName)) 417 if err != nil { 418 return err 419 } 420 } 421 422 return nil 423 } 424 425 func copyBranch(ctx *sql.Context, dbData env.DbData, apr *argparser.ArgParseResults, rsc *doltdb.ReplicationStatusController) error { 426 if apr.NArg() != 2 { 427 return InvalidArgErr 428 } 429 430 srcBr := apr.Args[0] 431 if len(srcBr) == 0 { 432 return EmptyBranchNameErr 433 } 434 435 destBr := apr.Args[1] 436 if len(destBr) == 0 { 437 return EmptyBranchNameErr 438 } 439 440 force := apr.Contains(cli.ForceFlag) 441 return copyABranch(ctx, dbData, srcBr, destBr, force, rsc) 442 } 443 444 func copyABranch(ctx *sql.Context, dbData env.DbData, srcBr string, destBr string, force bool, rsc *doltdb.ReplicationStatusController) error { 445 if err := branch_control.CanCreateBranch(ctx, destBr); err != nil { 446 return err 447 } 448 // If force is enabled, we can overwrite the destination branch, so we require a permission check here, even if the 449 // destination branch doesn't exist. An unauthorized user could simply rerun the command without the force flag. 
450 if force { 451 if err := branch_control.CanDeleteBranch(ctx, destBr); err != nil { 452 return err 453 } 454 } 455 err := actions.CopyBranchOnDB(ctx, dbData.Ddb, srcBr, destBr, force, rsc) 456 if err != nil { 457 if err == doltdb.ErrBranchNotFound { 458 return fmt.Errorf("fatal: A branch named '%s' not found", srcBr) 459 } else if err == actions.ErrAlreadyExists { 460 return fmt.Errorf("fatal: A branch named '%s' already exists.", destBr) 461 } else if err == doltdb.ErrInvBranchName { 462 return fmt.Errorf("fatal: '%s' is not a valid branch name.", destBr) 463 } else { 464 return fmt.Errorf("fatal: Unexpected error copying branch from '%s' to '%s'", srcBr, destBr) 465 } 466 } 467 err = branch_control.AddAdminForContext(ctx, destBr) 468 if err != nil { 469 return err 470 } 471 472 return nil 473 }