github.com/hasnat/dolt/go@v0.0.0-20210628190320-9eb5d843fbb7/libraries/doltcore/sqle/dfunctions/dolt_checkout.go (about)

     1  // Copyright 2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package dfunctions
    16  
    17  import (
    18  	"errors"
    19  	"fmt"
    20  	"strings"
    21  
    22  	"github.com/dolthub/go-mysql-server/sql"
    23  	"github.com/dolthub/go-mysql-server/sql/expression"
    24  
    25  	"github.com/dolthub/dolt/go/cmd/dolt/cli"
    26  	"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
    27  	"github.com/dolthub/dolt/go/libraries/doltcore/env"
    28  	"github.com/dolthub/dolt/go/libraries/doltcore/env/actions"
    29  	"github.com/dolthub/dolt/go/libraries/doltcore/sqle"
    30  )
    31  
    32  const DoltCheckoutFuncName = "dolt_checkout"
    33  
    34  var ErrEmptyBranchName = errors.New("error: cannot checkout empty string")
    35  
// DoltCheckoutFunc is the DOLT_CHECKOUT SQL function expression. Its call
// arguments are stored as the child expressions of the embedded
// NaryExpression.
type DoltCheckoutFunc struct {
	expression.NaryExpression
}
    39  
    40  func (d DoltCheckoutFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
    41  	dbName := ctx.GetCurrentDatabase()
    42  
    43  	if len(dbName) == 0 {
    44  		return 1, fmt.Errorf("Empty database name.")
    45  	}
    46  
    47  	dSess := sqle.DSessFromSess(ctx.Session)
    48  	dbData, ok := dSess.GetDbData(dbName)
    49  
    50  	if !ok {
    51  		return 1, fmt.Errorf("Could not load database %s", dbName)
    52  	}
    53  
    54  	ap := cli.CreateCheckoutArgParser()
    55  	args, err := getDoltArgs(ctx, row, d.Children())
    56  
    57  	if err != nil {
    58  		return 1, err
    59  	}
    60  
    61  	apr, err := ap.Parse(args)
    62  	if err != nil {
    63  		return 1, err
    64  	}
    65  
    66  	if (apr.Contains(cli.CheckoutCoBranch) && apr.NArg() > 1) || (!apr.Contains(cli.CheckoutCoBranch) && apr.NArg() == 0) {
    67  		return 1, errors.New("Improper usage.")
    68  	}
    69  
    70  	// Checking out new branch.
    71  	if newBranch, newBranchOk := apr.GetValue(cli.CheckoutCoBranch); newBranchOk {
    72  		if len(newBranch) == 0 {
    73  			err = errors.New("error: cannot checkout empty string")
    74  		} else {
    75  			err = checkoutNewBranch(ctx, dbData, newBranch, "")
    76  		}
    77  
    78  		if err != nil {
    79  			return 1, err
    80  		}
    81  
    82  		return 0, nil
    83  	}
    84  
    85  	name := apr.Arg(0)
    86  
    87  	// Check if user wants to checkout branch.
    88  	if isBranch, err := actions.IsBranch(ctx, dbData.Ddb, name); err != nil {
    89  		return 1, err
    90  	} else if isBranch {
    91  		err = checkoutBranch(ctx, dbData, name)
    92  		if err != nil {
    93  			return 1, err
    94  		}
    95  		return 0, nil
    96  	}
    97  
    98  	// Check if user want to checkout table or docs.
    99  	tbls, docs, err := actions.GetTablesOrDocs(dbData.Drw, args)
   100  	if err != nil {
   101  		return 1, errors.New("error: unable to parse arguments.")
   102  	}
   103  
   104  	if len(docs) > 0 {
   105  		return 1, errors.New("error: docs not supported in sql mode")
   106  	}
   107  
   108  	err = checkoutTables(ctx, dbData, tbls)
   109  
   110  	if err != nil && apr.NArg() == 1 {
   111  		err = checkoutRemoteBranch(ctx, dbData, name)
   112  	}
   113  
   114  	if err != nil {
   115  		return 1, err
   116  	}
   117  
   118  	return 0, nil
   119  }
   120  
   121  func checkoutRemoteBranch(ctx *sql.Context, dbData env.DbData, branchName string) error {
   122  	if len(branchName) == 0 {
   123  		return ErrEmptyBranchName
   124  	}
   125  
   126  	if ref, refExists, err := actions.GetRemoteBranchRef(ctx, dbData.Ddb, branchName); err != nil {
   127  		return errors.New("fatal: unable to read from data repository")
   128  	} else if refExists {
   129  		return checkoutNewBranch(ctx, dbData, branchName, ref.String())
   130  	} else {
   131  		return fmt.Errorf("error: could not find %s", branchName)
   132  	}
   133  }
   134  
   135  func checkoutNewBranch(ctx *sql.Context, dbData env.DbData, branchName, startPt string) error {
   136  	if len(branchName) == 0 {
   137  		return ErrEmptyBranchName
   138  	}
   139  
   140  	if startPt == "" {
   141  		startPt = "head"
   142  	}
   143  
   144  	err := actions.CreateBranchWithStartPt(ctx, dbData, branchName, startPt, false)
   145  	if err != nil {
   146  		return err
   147  	}
   148  
   149  	return checkoutBranch(ctx, dbData, branchName)
   150  }
   151  
   152  func checkoutBranch(ctx *sql.Context, dbData env.DbData, branchName string) error {
   153  	if len(branchName) == 0 {
   154  		return ErrEmptyBranchName
   155  	}
   156  
   157  	err := actions.CheckoutBranchWithoutDocs(ctx, dbData, branchName)
   158  
   159  	if err != nil {
   160  		if err == doltdb.ErrBranchNotFound {
   161  			return fmt.Errorf("fatal: Branch '%s' not found.", branchName)
   162  		} else if doltdb.IsRootValUnreachable(err) {
   163  			rt := doltdb.GetUnreachableRootType(err)
   164  			return fmt.Errorf("error: unable to read the %s", rt.String())
   165  		} else if actions.IsCheckoutWouldOverwrite(err) {
   166  			tbls := actions.CheckoutWouldOverwriteTables(err)
   167  			msg := "error: Your local changes to the following tables would be overwritten by checkout: \n"
   168  			for _, tbl := range tbls {
   169  				msg = msg + tbl + "\n"
   170  			}
   171  			return errors.New(msg)
   172  		} else if err == doltdb.ErrAlreadyOnBranch {
   173  			return nil // No need to return an error if on the same branch
   174  		} else {
   175  			return fmt.Errorf("fatal: Unexpected error checking out branch '%s'", branchName)
   176  		}
   177  	}
   178  
   179  	return updateHeadAndWorkingSessionVars(ctx, dbData)
   180  }
   181  
   182  func checkoutTables(ctx *sql.Context, dbData env.DbData, tables []string) error {
   183  	err := actions.CheckoutTables(ctx, dbData, tables)
   184  
   185  	if err != nil {
   186  		if doltdb.IsRootValUnreachable(err) {
   187  			rt := doltdb.GetUnreachableRootType(err)
   188  			return fmt.Errorf("error: unable to read the %s", rt.String())
   189  		} else if actions.IsTblNotExist(err) {
   190  			return fmt.Errorf("error: given tables do not exist")
   191  		} else {
   192  			return fmt.Errorf("fatal: Unexpected error checking out tables")
   193  		}
   194  	}
   195  
   196  	return updateHeadAndWorkingSessionVars(ctx, dbData)
   197  }
   198  
   199  // updateHeadAndWorkingSessionVars explicitly sets the head and working hash.
   200  func updateHeadAndWorkingSessionVars(ctx *sql.Context, dbData env.DbData) error {
   201  	headHash, err := dbData.Rsr.CWBHeadHash(ctx)
   202  	if err != nil {
   203  		return err
   204  	}
   205  	hs := headHash.String()
   206  
   207  	hasWorkingChanges := hasWorkingSetChanges(dbData.Rsr)
   208  	hasStagedChanges, err := hasStagedSetChanges(ctx, dbData.Ddb, dbData.Rsr)
   209  
   210  	if err != nil {
   211  		return err
   212  	}
   213  
   214  	workingHash := dbData.Rsr.WorkingHash().String()
   215  
   216  	// This will update the session table editor's root and clear its cache.
   217  	if !hasStagedChanges && !hasWorkingChanges {
   218  		return setHeadAndWorkingSessionRoot(ctx, hs)
   219  	}
   220  
   221  	err = setSessionRootExplicit(ctx, hs, sqle.HeadKeySuffix)
   222  	if err != nil {
   223  		return err
   224  	}
   225  
   226  	return setSessionRootExplicit(ctx, workingHash, sqle.WorkingKeySuffix)
   227  }
   228  
   229  func (d DoltCheckoutFunc) String() string {
   230  	childrenStrings := make([]string, len(d.Children()))
   231  
   232  	for i, child := range d.Children() {
   233  		childrenStrings[i] = child.String()
   234  	}
   235  
   236  	return fmt.Sprintf("DOLT_CHECKOUT(%s)", strings.Join(childrenStrings, ","))
   237  }
   238  
// Type reports the SQL type of DOLT_CHECKOUT's result, which is sql.Int8.
// NOTE(review): Eval returns untyped Go int values (0 or 1), not int8; confirm
// the engine converts them to the declared Int8 type as intended.
func (d DoltCheckoutFunc) Type() sql.Type {
	return sql.Int8
}
   242  
// WithChildren returns a new DOLT_CHECKOUT expression whose arguments are
// children. It delegates to NewDoltCheckoutFunc and does not modify the
// receiver.
func (d DoltCheckoutFunc) WithChildren(ctx *sql.Context, children ...sql.Expression) (sql.Expression, error) {
	return NewDoltCheckoutFunc(ctx, children...)
}
   246  
   247  func NewDoltCheckoutFunc(ctx *sql.Context, args ...sql.Expression) (sql.Expression, error) {
   248  	return &DoltCheckoutFunc{expression.NaryExpression{ChildExpressions: args}}, nil
   249  }