github.com/hasnat/dolt/go@v0.0.0-20210628190320-9eb5d843fbb7/libraries/doltcore/sqle/dfunctions/merge.go (about)

     1  // Copyright 2020 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package dfunctions
    16  
    17  import (
    18  	"errors"
    19  	"fmt"
    20  	"strings"
    21  
    22  	"github.com/dolthub/go-mysql-server/sql"
    23  
    24  	"github.com/dolthub/dolt/go/cmd/dolt/cli"
    25  	"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
    26  	"github.com/dolthub/dolt/go/libraries/doltcore/merge"
    27  	"github.com/dolthub/dolt/go/libraries/doltcore/sqle"
    28  	"github.com/dolthub/dolt/go/store/hash"
    29  )
    30  
    31  const MergeFuncName = "merge"
    32  
    33  type MergeFunc struct {
    34  	children []sql.Expression
    35  }
    36  
    37  // NewMergeFunc creates a new MergeFunc expression.
    38  func NewMergeFunc(ctx *sql.Context, args ...sql.Expression) (sql.Expression, error) {
    39  	return &MergeFunc{children: args}, nil
    40  }
    41  
    42  // Eval implements the Expression interface.
    43  func (cf *MergeFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
    44  	sess := sqle.DSessFromSess(ctx.Session)
    45  
    46  	// TODO: Move to a separate MERGE argparser.
    47  	ap := cli.CreateCommitArgParser()
    48  	args, err := getDoltArgs(ctx, row, cf.Children())
    49  
    50  	if err != nil {
    51  		return nil, err
    52  	}
    53  
    54  	apr, err := ap.Parse(args)
    55  	if err != nil {
    56  		return nil, err
    57  	}
    58  
    59  	// The fist argument should be the branch name.
    60  	branchName := apr.Arg(0)
    61  
    62  	var name, email string
    63  	if authorStr, ok := apr.GetValue(cli.AuthorParam); ok {
    64  		name, email, err = cli.ParseAuthor(authorStr)
    65  		if err != nil {
    66  			return nil, err
    67  		}
    68  	} else {
    69  		name = sess.Username
    70  		email = sess.Email
    71  	}
    72  
    73  	dbName := sess.GetCurrentDatabase()
    74  	ddb, ok := sess.GetDoltDB(dbName)
    75  	if !ok {
    76  		return nil, sql.ErrDatabaseNotFound.New(dbName)
    77  	}
    78  
    79  	root, ok := sess.GetRoot(dbName)
    80  	if !ok {
    81  		return nil, sql.ErrDatabaseNotFound.New(dbName)
    82  	}
    83  
    84  	head, hh, headRoot, err := getHead(ctx, sess, dbName)
    85  	if err != nil {
    86  		return nil, err
    87  	}
    88  
    89  	err = checkForUncommittedChanges(root, headRoot)
    90  	if err != nil {
    91  		return nil, err
    92  	}
    93  
    94  	cm, cmh, err := getBranchCommit(ctx, branchName, ddb)
    95  	if err != nil {
    96  		return nil, err
    97  	}
    98  
    99  	// No need to write a merge commit, if the head can ffw to the commit coming from the branch.
   100  	canFF, err := head.CanFastForwardTo(ctx, cm)
   101  	if err != nil {
   102  		return nil, err
   103  	}
   104  
   105  	if canFF {
   106  		return cmh.String(), nil
   107  	}
   108  
   109  	mergeRoot, _, err := merge.MergeCommits(ctx, head, cm)
   110  
   111  	if err != nil {
   112  		return nil, err
   113  	}
   114  
   115  	h, err := ddb.WriteRootValue(ctx, mergeRoot)
   116  	if err != nil {
   117  		return nil, err
   118  	}
   119  
   120  	commitMessage := fmt.Sprintf("SQL Generated commit merging %s into %s", hh.String(), cmh.String())
   121  	meta, err := doltdb.NewCommitMeta(name, email, commitMessage)
   122  	if err != nil {
   123  		return nil, err
   124  	}
   125  
   126  	mergeCommit, err := ddb.CommitDanglingWithParentCommits(ctx, h, []*doltdb.Commit{head, cm}, meta)
   127  	if err != nil {
   128  		return nil, err
   129  	}
   130  
   131  	h, err = mergeCommit.HashOf()
   132  	if err != nil {
   133  		return nil, err
   134  	}
   135  
   136  	return h.String(), nil
   137  }
   138  
   139  func checkForUncommittedChanges(root *doltdb.RootValue, headRoot *doltdb.RootValue) error {
   140  	rh, err := root.HashOf()
   141  
   142  	if err != nil {
   143  		return err
   144  	}
   145  
   146  	hrh, err := headRoot.HashOf()
   147  
   148  	if err != nil {
   149  		return err
   150  	}
   151  
   152  	if rh != hrh {
   153  		return errors.New("cannot merge with uncommitted changes")
   154  	}
   155  
   156  	return nil
   157  }
   158  
   159  func getBranchCommit(ctx *sql.Context, val interface{}, ddb *doltdb.DoltDB) (*doltdb.Commit, hash.Hash, error) {
   160  	paramStr, ok := val.(string)
   161  
   162  	if !ok {
   163  		return nil, hash.Hash{}, errors.New("branch name is not a string")
   164  	}
   165  
   166  	branchRef, err := getBranchInsensitive(ctx, paramStr, ddb)
   167  
   168  	if err != nil {
   169  		return nil, hash.Hash{}, err
   170  	}
   171  
   172  	cm, err := ddb.ResolveCommitRef(ctx, branchRef)
   173  
   174  	if err != nil {
   175  		return nil, hash.Hash{}, err
   176  	}
   177  
   178  	cmh, err := cm.HashOf()
   179  
   180  	if err != nil {
   181  		return nil, hash.Hash{}, err
   182  	}
   183  
   184  	return cm, cmh, nil
   185  }
   186  
   187  func getHead(ctx *sql.Context, sess *sqle.DoltSession, dbName string) (*doltdb.Commit, hash.Hash, *doltdb.RootValue, error) {
   188  	head, hh, err := sess.GetHeadCommit(ctx, dbName)
   189  	if err != nil {
   190  		return nil, hash.Hash{}, nil, err
   191  	}
   192  
   193  	headRoot, err := head.GetRootValue()
   194  	if err != nil {
   195  		return nil, hash.Hash{}, nil, err
   196  	}
   197  
   198  	return head, hh, headRoot, nil
   199  }
   200  
   201  // String implements the Stringer interface.
   202  func (cf *MergeFunc) String() string {
   203  	childrenStrings := make([]string, len(cf.children))
   204  
   205  	for i, child := range cf.children {
   206  		childrenStrings[i] = child.String()
   207  	}
   208  	return fmt.Sprintf("Merge(%s)", strings.Join(childrenStrings, ","))
   209  }
   210  
   211  // IsNullable implements the Expression interface.
   212  func (cf *MergeFunc) IsNullable() bool {
   213  	return false
   214  }
   215  
   216  func (cf *MergeFunc) Resolved() bool {
   217  	for _, child := range cf.Children() {
   218  		if !child.Resolved() {
   219  			return false
   220  		}
   221  	}
   222  	return true
   223  }
   224  
   225  func (cf *MergeFunc) Children() []sql.Expression {
   226  	return cf.children
   227  }
   228  
   229  // WithChildren implements the Expression interface.
   230  func (cf *MergeFunc) WithChildren(ctx *sql.Context, children ...sql.Expression) (sql.Expression, error) {
   231  	return NewMergeFunc(ctx, children...)
   232  }
   233  
   234  // Type implements the Expression interface.
   235  func (cf *MergeFunc) Type() sql.Type {
   236  	return sql.Text
   237  }