github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/sqle/dprocedures/dolt_revert.go (about) 1 // Copyright 2022 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package dprocedures 16 17 import ( 18 "fmt" 19 20 "github.com/dolthub/go-mysql-server/sql" 21 "github.com/dolthub/go-mysql-server/sql/expression" 22 23 "github.com/dolthub/dolt/go/cmd/dolt/cli" 24 "github.com/dolthub/dolt/go/libraries/doltcore/branch_control" 25 "github.com/dolthub/dolt/go/libraries/doltcore/diff" 26 "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" 27 "github.com/dolthub/dolt/go/libraries/doltcore/merge" 28 "github.com/dolthub/dolt/go/libraries/doltcore/schema/typeinfo" 29 "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess" 30 ) 31 32 // doltRevert is the stored procedure version for the CLI command `dolt revert`. 33 func doltRevert(ctx *sql.Context, args ...string) (sql.RowIter, error) { 34 res, err := doDoltRevert(ctx, args) 35 if err != nil { 36 return nil, err 37 } 38 return rowToIter(int64(res)), nil 39 } 40 41 func doDoltRevert(ctx *sql.Context, args []string) (int, error) { 42 dbName := ctx.GetCurrentDatabase() 43 dSess := dsess.DSessFromSess(ctx.Session) 44 ddb, ok := dSess.GetDoltDB(ctx, dbName) 45 if !ok { 46 return 1, fmt.Errorf("dolt database could not be found") 47 } 48 if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil { 49 return 1, err 50 } 51 52 roots, ok := dSess.GetRoots(ctx, dbName) 53 if !ok { 54 return 1, fmt.Errorf("Could not load session roots") 55 } 56 wsOnlyHasIgnoredTables, err := diff.WorkingSetContainsOnlyIgnoredTables(ctx, roots) 57 if err != nil { 58 return 1, err 59 } else if !wsOnlyHasIgnoredTables { 60 return 1, fmt.Errorf("You must commit any changes before using revert") 61 } 62 63 workingSet, err := dSess.WorkingSet(ctx, dbName) 64 if err != nil { 65 return 1, err 66 } 67 workingRoot := workingSet.WorkingRoot() 68 headCommit, err := dSess.GetHeadCommit(ctx, dbName) 69 if err != nil { 70 return 1, err 71 } 72 headRoot, err := headCommit.GetRootValue(ctx) 73 if err != nil { 74 return 1, err 75 } 76 headHash, err := headRoot.HashOf() 77 if err != nil { 78 return 1, err 79 } 80 workingHash, err := workingRoot.HashOf() 81 if err != nil { 82 return 1, err 83 } 84 85 headRef, err := dSess.CWBHeadRef(ctx, dbName) 86 if err != nil { 87 return 1, err 88 } 89 90 apr, err := cli.CreateRevertArgParser().Parse(args) 91 if err != nil { 92 return 1, err 93 } 94 95 commits := make([]*doltdb.Commit, apr.NArg()) 96 for i, revisionStr := range apr.Args { 97 commitSpec, err := doltdb.NewCommitSpec(revisionStr) 98 if err != nil { 99 return 1, err 100 } 101 optCmt, err := ddb.Resolve(ctx, commitSpec, headRef) 102 if err != nil { 103 return 1, err 104 } 105 commit, ok := optCmt.ToCommit() 106 if !ok { 107 return 1, doltdb.ErrGhostCommitEncountered 108 } 109 110 commits[i] = commit 111 } 112 113 dbState, ok, err := dSess.LookupDbState(ctx, dbName) 114 if err != nil { 115 return 1, err 116 } else if !ok { 117 return 1, fmt.Errorf("Could not load database %s", dbName) 118 } 119 120 workingRoot, revertMessage, err := merge.Revert(ctx, ddb, workingRoot, commits, dbState.EditOpts()) 121 if err != nil { 122 return 1, err 123 } 124 workingHash, err = workingRoot.HashOf() 125 if err != nil { 126 return 1, err 127 } 128 if !headHash.Equal(workingHash) { 129 err = dSess.SetWorkingRoot(ctx, dbName, workingRoot) 130 if err != nil { 131 return 1, err 132 } 133 stringType := typeinfo.StringDefaultType.ToSqlType() 134 135 expressions := []sql.Expression{expression.NewLiteral("-a", stringType), expression.NewLiteral("-m", stringType), expression.NewLiteral(revertMessage, stringType)} 136 137 author, hasAuthor := apr.GetValue(cli.AuthorParam) 138 if hasAuthor { 139 expressions = append(expressions, expression.NewLiteral("--author", stringType), expression.NewLiteral(author, stringType)) 140 } 141 142 commitArgs, err := getDoltArgs(ctx, nil, expressions) 143 if err != nil { 144 return 1, err 145 } 146 _, _, err = doDoltCommit(ctx, commitArgs) 147 if err != nil { 148 return 1, err 149 } 150 } 151 return 0, nil 152 }