github.com/hasnat/dolt/go@v0.0.0-20210628190320-9eb5d843fbb7/libraries/doltcore/sqle/dfunctions/dolt_checkout.go (about) 1 // Copyright 2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package dfunctions 16 17 import ( 18 "errors" 19 "fmt" 20 "strings" 21 22 "github.com/dolthub/go-mysql-server/sql" 23 "github.com/dolthub/go-mysql-server/sql/expression" 24 25 "github.com/dolthub/dolt/go/cmd/dolt/cli" 26 "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" 27 "github.com/dolthub/dolt/go/libraries/doltcore/env" 28 "github.com/dolthub/dolt/go/libraries/doltcore/env/actions" 29 "github.com/dolthub/dolt/go/libraries/doltcore/sqle" 30 ) 31 32 const DoltCheckoutFuncName = "dolt_checkout" 33 34 var ErrEmptyBranchName = errors.New("error: cannot checkout empty string") 35 36 type DoltCheckoutFunc struct { 37 expression.NaryExpression 38 } 39 40 func (d DoltCheckoutFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { 41 dbName := ctx.GetCurrentDatabase() 42 43 if len(dbName) == 0 { 44 return 1, fmt.Errorf("Empty database name.") 45 } 46 47 dSess := sqle.DSessFromSess(ctx.Session) 48 dbData, ok := dSess.GetDbData(dbName) 49 50 if !ok { 51 return 1, fmt.Errorf("Could not load database %s", dbName) 52 } 53 54 ap := cli.CreateCheckoutArgParser() 55 args, err := getDoltArgs(ctx, row, d.Children()) 56 57 if err != nil { 58 return 1, err 59 } 60 61 apr, err := ap.Parse(args) 62 if err != nil { 63 return 1, err 64 } 65 66 if (apr.Contains(cli.CheckoutCoBranch) && apr.NArg() > 1) || (!apr.Contains(cli.CheckoutCoBranch) && apr.NArg() == 0) { 67 return 1, errors.New("Improper usage.") 68 } 69 70 // Checking out new branch. 71 if newBranch, newBranchOk := apr.GetValue(cli.CheckoutCoBranch); newBranchOk { 72 if len(newBranch) == 0 { 73 err = errors.New("error: cannot checkout empty string") 74 } else { 75 err = checkoutNewBranch(ctx, dbData, newBranch, "") 76 } 77 78 if err != nil { 79 return 1, err 80 } 81 82 return 0, nil 83 } 84 85 name := apr.Arg(0) 86 87 // Check if user wants to checkout branch. 88 if isBranch, err := actions.IsBranch(ctx, dbData.Ddb, name); err != nil { 89 return 1, err 90 } else if isBranch { 91 err = checkoutBranch(ctx, dbData, name) 92 if err != nil { 93 return 1, err 94 } 95 return 0, nil 96 } 97 98 // Check if user want to checkout table or docs. 99 tbls, docs, err := actions.GetTablesOrDocs(dbData.Drw, args) 100 if err != nil { 101 return 1, errors.New("error: unable to parse arguments.") 102 } 103 104 if len(docs) > 0 { 105 return 1, errors.New("error: docs not supported in sql mode") 106 } 107 108 err = checkoutTables(ctx, dbData, tbls) 109 110 if err != nil && apr.NArg() == 1 { 111 err = checkoutRemoteBranch(ctx, dbData, name) 112 } 113 114 if err != nil { 115 return 1, err 116 } 117 118 return 0, nil 119 } 120 121 func checkoutRemoteBranch(ctx *sql.Context, dbData env.DbData, branchName string) error { 122 if len(branchName) == 0 { 123 return ErrEmptyBranchName 124 } 125 126 if ref, refExists, err := actions.GetRemoteBranchRef(ctx, dbData.Ddb, branchName); err != nil { 127 return errors.New("fatal: unable to read from data repository") 128 } else if refExists { 129 return checkoutNewBranch(ctx, dbData, branchName, ref.String()) 130 } else { 131 return fmt.Errorf("error: could not find %s", branchName) 132 } 133 } 134 135 func checkoutNewBranch(ctx *sql.Context, dbData env.DbData, branchName, startPt string) error { 136 if len(branchName) == 0 { 137 return ErrEmptyBranchName 138 } 139 140 if startPt == "" { 141 startPt = "head" 142 } 143 144 err := actions.CreateBranchWithStartPt(ctx, dbData, branchName, startPt, false) 145 if err != nil { 146 return err 147 } 148 149 return checkoutBranch(ctx, dbData, branchName) 150 } 151 152 func checkoutBranch(ctx *sql.Context, dbData env.DbData, branchName string) error { 153 if len(branchName) == 0 { 154 return ErrEmptyBranchName 155 } 156 157 err := actions.CheckoutBranchWithoutDocs(ctx, dbData, branchName) 158 159 if err != nil { 160 if err == doltdb.ErrBranchNotFound { 161 return fmt.Errorf("fatal: Branch '%s' not found.", branchName) 162 } else if doltdb.IsRootValUnreachable(err) { 163 rt := doltdb.GetUnreachableRootType(err) 164 return fmt.Errorf("error: unable to read the %s", rt.String()) 165 } else if actions.IsCheckoutWouldOverwrite(err) { 166 tbls := actions.CheckoutWouldOverwriteTables(err) 167 msg := "error: Your local changes to the following tables would be overwritten by checkout: \n" 168 for _, tbl := range tbls { 169 msg = msg + tbl + "\n" 170 } 171 return errors.New(msg) 172 } else if err == doltdb.ErrAlreadyOnBranch { 173 return nil // No need to return an error if on the same branch 174 } else { 175 return fmt.Errorf("fatal: Unexpected error checking out branch '%s'", branchName) 176 } 177 } 178 179 return updateHeadAndWorkingSessionVars(ctx, dbData) 180 } 181 182 func checkoutTables(ctx *sql.Context, dbData env.DbData, tables []string) error { 183 err := actions.CheckoutTables(ctx, dbData, tables) 184 185 if err != nil { 186 if doltdb.IsRootValUnreachable(err) { 187 rt := doltdb.GetUnreachableRootType(err) 188 return fmt.Errorf("error: unable to read the %s", rt.String()) 189 } else if actions.IsTblNotExist(err) { 190 return fmt.Errorf("error: given tables do not exist") 191 } else { 192 return fmt.Errorf("fatal: Unexpected error checking out tables") 193 } 194 } 195 196 return updateHeadAndWorkingSessionVars(ctx, dbData) 197 } 198 199 // updateHeadAndWorkingSessionVars explicitly sets the head and working hash. 200 func updateHeadAndWorkingSessionVars(ctx *sql.Context, dbData env.DbData) error { 201 headHash, err := dbData.Rsr.CWBHeadHash(ctx) 202 if err != nil { 203 return err 204 } 205 hs := headHash.String() 206 207 hasWorkingChanges := hasWorkingSetChanges(dbData.Rsr) 208 hasStagedChanges, err := hasStagedSetChanges(ctx, dbData.Ddb, dbData.Rsr) 209 210 if err != nil { 211 return err 212 } 213 214 workingHash := dbData.Rsr.WorkingHash().String() 215 216 // This will update the session table editor's root and clear its cache. 217 if !hasStagedChanges && !hasWorkingChanges { 218 return setHeadAndWorkingSessionRoot(ctx, hs) 219 } 220 221 err = setSessionRootExplicit(ctx, hs, sqle.HeadKeySuffix) 222 if err != nil { 223 return err 224 } 225 226 return setSessionRootExplicit(ctx, workingHash, sqle.WorkingKeySuffix) 227 } 228 229 func (d DoltCheckoutFunc) String() string { 230 childrenStrings := make([]string, len(d.Children())) 231 232 for i, child := range d.Children() { 233 childrenStrings[i] = child.String() 234 } 235 236 return fmt.Sprintf("DOLT_CHECKOUT(%s)", strings.Join(childrenStrings, ",")) 237 } 238 239 func (d DoltCheckoutFunc) Type() sql.Type { 240 return sql.Int8 241 } 242 243 func (d DoltCheckoutFunc) WithChildren(ctx *sql.Context, children ...sql.Expression) (sql.Expression, error) { 244 return NewDoltCheckoutFunc(ctx, children...) 245 } 246 247 func NewDoltCheckoutFunc(ctx *sql.Context, args ...sql.Expression) (sql.Expression, error) { 248 return &DoltCheckoutFunc{expression.NaryExpression{ChildExpressions: args}}, nil 249 }