github.com/hasnat/dolt/go@v0.0.0-20210628190320-9eb5d843fbb7/libraries/doltcore/sqle/dfunctions/dolt_merge.go (about) 1 // Copyright 2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package dfunctions 16 17 import ( 18 "errors" 19 "fmt" 20 "strings" 21 22 "github.com/dolthub/go-mysql-server/sql" 23 "github.com/dolthub/go-mysql-server/sql/expression" 24 25 "github.com/dolthub/dolt/go/cmd/dolt/cli" 26 "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" 27 "github.com/dolthub/dolt/go/libraries/doltcore/env" 28 "github.com/dolthub/dolt/go/libraries/doltcore/env/actions" 29 "github.com/dolthub/dolt/go/libraries/doltcore/merge" 30 "github.com/dolthub/dolt/go/libraries/doltcore/sqle" 31 "github.com/dolthub/dolt/go/libraries/utils/argparser" 32 ) 33 34 const DoltMergeFuncName = "dolt_merge" 35 36 type DoltMergeFunc struct { 37 expression.NaryExpression 38 } 39 40 func (d DoltMergeFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { 41 dbName := ctx.GetCurrentDatabase() 42 43 if len(dbName) == 0 { 44 return 1, fmt.Errorf("Empty database name.") 45 } 46 47 sess := sqle.DSessFromSess(ctx.Session) 48 dbData, ok := sess.GetDbData(dbName) 49 50 if !ok { 51 return 1, fmt.Errorf("Could not load database %s", dbName) 52 } 53 54 ap := cli.CreateMergeArgParser() 55 args, err := getDoltArgs(ctx, row, d.Children()) 56 57 if err != nil { 58 return nil, err 59 } 60 61 apr, err := ap.Parse(args) 62 if err != nil { 63 return nil, err 64 } 65 66 if apr.ContainsAll(cli.SquashParam, cli.NoFFParam) { 67 return 1, fmt.Errorf("error: Flags '--%s' and '--%s' cannot be used together.\n", cli.SquashParam, cli.NoFFParam) 68 } 69 70 if apr.Contains(cli.AbortParam) { 71 if !dbData.Rsr.IsMergeActive() { 72 return 1, fmt.Errorf("fatal: There is no merge to abort") 73 } 74 75 err = abortMerge(ctx, dbData) 76 77 if err != nil { 78 return 1, err 79 } 80 81 return "Merge aborted", nil 82 } 83 84 // The first argument should be the branch name. 85 branchName := apr.Arg(0) 86 87 ddb, ok := sess.GetDoltDB(dbName) 88 if !ok { 89 return nil, sql.ErrDatabaseNotFound.New(dbName) 90 } 91 92 root, ok := sess.GetRoot(dbName) 93 if !ok { 94 return nil, sql.ErrDatabaseNotFound.New(dbName) 95 } 96 97 hasConflicts, err := root.HasConflicts(ctx) 98 if err != nil { 99 return 1, err 100 } 101 102 if hasConflicts { 103 return 1, doltdb.ErrUnresolvedConflicts 104 } 105 106 if dbData.Rsr.IsMergeActive() { 107 return 1, doltdb.ErrMergeActive 108 } 109 110 head, hh, headRoot, err := getHead(ctx, sess, dbName) 111 if err != nil { 112 return nil, err 113 } 114 115 err = checkForUncommittedChanges(root, headRoot) 116 if err != nil { 117 return nil, err 118 } 119 120 cm, cmh, err := getBranchCommit(ctx, branchName, ddb) 121 if err != nil { 122 return nil, err 123 } 124 125 // No need to write a merge commit, if the head can ffw to the commit coming from the branch. 126 canFF, err := head.CanFastForwardTo(ctx, cm) 127 if err != nil { 128 return nil, err 129 } 130 131 if canFF { 132 if apr.Contains(cli.NoFFParam) { 133 err = executeNoFFMerge(ctx, sess, apr, dbName, dbData, head, cm) 134 } else { 135 err = executeFFMerge(ctx, apr.Contains(cli.SquashParam), dbName, dbData, cm) 136 } 137 138 if err != nil { 139 return nil, err 140 } 141 return cmh.String(), err 142 } 143 144 err = executeMerge(ctx, apr.Contains(cli.SquashParam), head, cm, dbName, dbData) 145 if err != nil { 146 return nil, err 147 } 148 149 returnMsg := fmt.Sprintf("Updating %s..%s", cmh.String(), hh.String()) 150 151 return returnMsg, nil 152 } 153 154 func abortMerge(ctx *sql.Context, dbData env.DbData) error { 155 err := actions.CheckoutAllTables(ctx, dbData) 156 157 if err != nil { 158 return err 159 } 160 161 err = dbData.Rsw.AbortMerge() 162 if err != nil { 163 return err 164 } 165 166 hh, err := dbData.Rsr.CWBHeadHash(ctx) 167 if err != nil { 168 return err 169 } 170 171 return setHeadAndWorkingSessionRoot(ctx, hh.String()) 172 } 173 174 func executeMerge(ctx *sql.Context, squash bool, head, cm *doltdb.Commit, name string, dbData env.DbData) error { 175 mergeRoot, mergeStats, err := merge.MergeCommits(ctx, head, cm) 176 177 if err != nil { 178 switch err { 179 case doltdb.ErrUpToDate: 180 return errors.New("Already up to date.") 181 case merge.ErrFastForward: 182 panic("fast forward merge") 183 default: 184 return errors.New("Bad merge") 185 } 186 } 187 188 return mergeRootToWorking(ctx, squash, name, dbData, mergeRoot, cm, mergeStats) 189 } 190 191 func executeFFMerge(ctx *sql.Context, squash bool, dbName string, dbData env.DbData, cm2 *doltdb.Commit) error { 192 rv, err := cm2.GetRootValue() 193 194 if err != nil { 195 return errors.New("Failed to return root value.") 196 } 197 198 stagedHash, err := dbData.Ddb.WriteRootValue(ctx, rv) 199 200 if err != nil { 201 return err 202 } 203 204 workingHash := stagedHash 205 if !squash { 206 err = dbData.Ddb.FastForward(ctx, dbData.Rsr.CWBHeadRef(), cm2) 207 208 if err != nil { 209 return err 210 } 211 } 212 213 err = dbData.Rsw.SetWorkingHash(ctx, workingHash) 214 if err != nil { 215 return err 216 } 217 218 err = dbData.Rsw.SetStagedHash(ctx, stagedHash) 219 if err != nil { 220 return err 221 } 222 223 hh, err := dbData.Rsr.CWBHeadHash(ctx) 224 if err != nil { 225 return err 226 } 227 228 if squash { 229 return ctx.SetSessionVariable(ctx, sqle.WorkingKey(dbName), workingHash.String()) 230 } else { 231 return setHeadAndWorkingSessionRoot(ctx, hh.String()) 232 } 233 } 234 235 func executeNoFFMerge( 236 ctx *sql.Context, 237 dSess *sqle.DoltSession, 238 apr *argparser.ArgParseResults, 239 dbName string, 240 dbData env.DbData, 241 pr, cm2 *doltdb.Commit, 242 ) error { 243 mergedRoot, err := cm2.GetRootValue() 244 if err != nil { 245 return errors.New("Failed to return root value.") 246 } 247 248 err = mergeRootToWorking(ctx, false, dbName, dbData, mergedRoot, cm2, map[string]*merge.MergeStats{}) 249 if err != nil { 250 return err 251 } 252 253 msg, msgOk := apr.GetValue(cli.CommitMessageArg) 254 if !msgOk { 255 hh, err := pr.HashOf() 256 if err != nil { 257 return err 258 } 259 260 cmh, err := cm2.HashOf() 261 if err != nil { 262 return err 263 } 264 265 msg = fmt.Sprintf("SQL Generated commit merging %s into %s", hh.String(), cmh.String()) 266 } 267 268 var name, email string 269 if authorStr, ok := apr.GetValue(cli.AuthorParam); ok { 270 name, email, err = cli.ParseAuthor(authorStr) 271 if err != nil { 272 return err 273 } 274 } else { 275 name = dSess.Username 276 email = dSess.Email 277 } 278 279 // Specify the time if the date parameter is not. 280 t := ctx.QueryTime() 281 if commitTimeStr, ok := apr.GetValue(cli.DateParam); ok { 282 var err error 283 t, err = cli.ParseDate(commitTimeStr) 284 if err != nil { 285 return err 286 } 287 } 288 289 h, err := actions.CommitStaged(ctx, dbData, actions.CommitStagedProps{ 290 Message: msg, 291 Date: t, 292 AllowEmpty: apr.Contains(cli.AllowEmptyFlag), 293 CheckForeignKeys: !apr.Contains(cli.ForceFlag), 294 Name: name, 295 Email: email, 296 }) 297 298 if err != nil { 299 return err 300 } 301 302 return setHeadAndWorkingSessionRoot(ctx, h) 303 } 304 305 func mergeRootToWorking( 306 ctx *sql.Context, 307 squash bool, 308 dbName string, 309 dbData env.DbData, 310 mergedRoot *doltdb.RootValue, 311 cm2 *doltdb.Commit, 312 mergeStats map[string]*merge.MergeStats, 313 ) error { 314 h2, err := cm2.HashOf() 315 if err != nil { 316 return err 317 } 318 319 workingRoot := mergedRoot 320 if !squash { 321 err = dbData.Rsw.StartMerge(h2.String()) 322 323 if err != nil { 324 return err 325 } 326 } 327 328 workingHash, err := env.UpdateWorkingRoot(ctx, dbData.Ddb, dbData.Rsw, workingRoot) 329 if err != nil { 330 return err 331 } 332 333 hasConflicts := checkForConflicts(mergeStats) 334 335 if hasConflicts { 336 // If there are conflicts write them to the working root anyway too allow for merge resolution via the dolt_conflicts 337 // table. 338 err := ctx.SetSessionVariable(ctx, sqle.WorkingKey(dbName), workingHash.String()) 339 if err != nil { 340 return err 341 } 342 343 return doltdb.ErrUnresolvedConflicts 344 } 345 346 _, err = env.UpdateStagedRoot(ctx, dbData.Ddb, dbData.Rsw, workingRoot) 347 if err != nil { 348 return err 349 } 350 351 return ctx.SetSessionVariable(ctx, sqle.WorkingKey(dbName), workingHash.String()) 352 } 353 354 func checkForConflicts(tblToStats map[string]*merge.MergeStats) bool { 355 for _, stats := range tblToStats { 356 if stats.Operation == merge.TableModified && stats.Conflicts > 0 { 357 return true 358 } 359 } 360 361 return false 362 } 363 364 func (d DoltMergeFunc) String() string { 365 childrenStrings := make([]string, len(d.Children())) 366 367 for i, child := range d.Children() { 368 childrenStrings[i] = child.String() 369 } 370 371 return fmt.Sprintf("DOLT_MERGE(%s)", strings.Join(childrenStrings, ",")) 372 } 373 374 func (d DoltMergeFunc) Type() sql.Type { 375 return sql.Text 376 } 377 378 func (d DoltMergeFunc) WithChildren(ctx *sql.Context, children ...sql.Expression) (sql.Expression, error) { 379 return NewDoltMergeFunc(ctx, children...) 380 } 381 382 func NewDoltMergeFunc(ctx *sql.Context, args ...sql.Expression) (sql.Expression, error) { 383 return &DoltMergeFunc{expression.NaryExpression{ChildExpressions: args}}, nil 384 }