github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/sqle/dprocedures/dolt_reset.go (about)

     1  // Copyright 2022 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package dprocedures
    16  
    17  import (
    18  	"fmt"
    19  
    20  	"github.com/dolthub/go-mysql-server/sql"
    21  
    22  	"github.com/dolthub/dolt/go/cmd/dolt/cli"
    23  	"github.com/dolthub/dolt/go/libraries/doltcore/branch_control"
    24  	"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
    25  	"github.com/dolthub/dolt/go/libraries/doltcore/env/actions"
    26  	"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
    27  )
    28  
    29  // doltReset is the stored procedure version for the CLI command `dolt reset`.
    30  func doltReset(ctx *sql.Context, args ...string) (sql.RowIter, error) {
    31  	res, err := doDoltReset(ctx, args)
    32  	if err != nil {
    33  		return nil, err
    34  	}
    35  	return rowToIter(int64(res)), nil
    36  }
    37  
    38  func doDoltReset(ctx *sql.Context, args []string) (int, error) {
    39  	dbName := ctx.GetCurrentDatabase()
    40  
    41  	if len(dbName) == 0 {
    42  		return 1, fmt.Errorf("Empty database name.")
    43  	}
    44  	if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
    45  		return 1, err
    46  	}
    47  
    48  	dSess := dsess.DSessFromSess(ctx.Session)
    49  	dbData, ok := dSess.GetDbData(ctx, dbName)
    50  
    51  	if !ok {
    52  		return 1, fmt.Errorf("Could not load database %s", dbName)
    53  	}
    54  
    55  	apr, err := cli.CreateResetArgParser().Parse(args)
    56  	if err != nil {
    57  		return 1, err
    58  	}
    59  
    60  	// Check if problems with args first.
    61  	if apr.ContainsAll(cli.HardResetParam, cli.SoftResetParam) {
    62  		return 1, fmt.Errorf("error: --%s and --%s are mutually exclusive options.", cli.HardResetParam, cli.SoftResetParam)
    63  	}
    64  
    65  	// Disallow manipulating any roots for read-only databases – changing the branch
    66  	// HEAD would allow changing data, and working set and index shouldn't ever have
    67  	// any contents for a read-only database.
    68  	isReadOnly, err := isReadOnlyDatabase(ctx, dbName)
    69  	if err != nil {
    70  		return 1, err
    71  	}
    72  	if isReadOnly {
    73  		return 1, fmt.Errorf("unable to reset HEAD in read-only databases")
    74  	}
    75  
    76  	// Get all the needed roots.
    77  	roots, ok := dSess.GetRoots(ctx, dbName)
    78  	if !ok {
    79  		return 1, fmt.Errorf("Could not load database %s", dbName)
    80  	}
    81  
    82  	if apr.Contains(cli.HardResetParam) {
    83  		// Get the commitSpec for the branch if it exists
    84  		arg := ""
    85  		if apr.NArg() > 1 {
    86  			return 1, fmt.Errorf("--hard supports at most one additional param")
    87  		} else if apr.NArg() == 1 {
    88  			arg = apr.Arg(0)
    89  		}
    90  
    91  		var newHead *doltdb.Commit
    92  		newHead, roots, err = actions.ResetHardTables(ctx, dbData, arg, roots)
    93  		if err != nil {
    94  			return 1, err
    95  		}
    96  
    97  		// TODO: this overrides the transaction setting, needs to happen at commit, not here
    98  		if newHead != nil {
    99  			headRef, err := dbData.Rsr.CWBHeadRef()
   100  			if err != nil {
   101  				return 1, err
   102  			}
   103  			if err := dbData.Ddb.SetHeadToCommit(ctx, headRef, newHead); err != nil {
   104  				return 1, err
   105  			}
   106  		}
   107  
   108  		// TODO - refactor and make transactional with the head update above.
   109  		ws, err := dSess.WorkingSet(ctx, dbName)
   110  		if err != nil {
   111  			return 1, err
   112  		}
   113  		err = dSess.SetWorkingSet(ctx, dbName, ws.WithWorkingRoot(roots.Working).WithStagedRoot(roots.Staged).ClearMerge().ClearRebase())
   114  		if err != nil {
   115  			return 1, err
   116  		}
   117  
   118  	} else if apr.Contains(cli.SoftResetParam) {
   119  		arg := ""
   120  		if apr.NArg() > 1 {
   121  			return 1, fmt.Errorf("--soft supports at most one additional param")
   122  		} else if apr.NArg() == 1 {
   123  			arg = apr.Arg(0)
   124  		}
   125  
   126  		if arg != "" {
   127  			roots, err = actions.ResetSoftToRef(ctx, dbData, arg)
   128  			if err != nil {
   129  				return 1, err
   130  			}
   131  			ws, err := dSess.WorkingSet(ctx, dbName)
   132  			if err != nil {
   133  				return 1, err
   134  			}
   135  			err = dSess.SetWorkingSet(ctx, dbName, ws.WithStagedRoot(roots.Staged).ClearMerge().ClearRebase())
   136  			if err != nil {
   137  				return 1, err
   138  			}
   139  		}
   140  	} else {
   141  		if apr.NArg() != 1 || (apr.NArg() == 1 && apr.Arg(0) == ".") {
   142  			roots, err = actions.ResetSoftTables(ctx, dbData, apr, roots)
   143  			if err != nil {
   144  				return 1, err
   145  			}
   146  			err = dSess.SetRoots(ctx, dbName, roots)
   147  			if err != nil {
   148  				return 1, err
   149  			}
   150  		} else {
   151  			// check if the input is a table name or commit ref
   152  			_, okHead, _ := roots.Head.ResolveTableName(ctx, apr.Arg(0))
   153  			_, okStaged, _ := roots.Staged.ResolveTableName(ctx, apr.Arg(0))
   154  			_, okWorking, _ := roots.Working.ResolveTableName(ctx, apr.Arg(0))
   155  			if okHead || okStaged || okWorking {
   156  				roots, err = actions.ResetSoftTables(ctx, dbData, apr, roots)
   157  				if err != nil {
   158  					return 1, err
   159  				}
   160  				err = dSess.SetRoots(ctx, dbName, roots)
   161  				if err != nil {
   162  					return 1, err
   163  				}
   164  			} else {
   165  				roots, err = actions.ResetSoftToRef(ctx, dbData, apr.Arg(0))
   166  				if err != nil {
   167  					return 1, err
   168  				}
   169  				ws, err := dSess.WorkingSet(ctx, dbName)
   170  				if err != nil {
   171  					return 1, err
   172  				}
   173  				err = dSess.SetWorkingSet(ctx, dbName, ws.WithStagedRoot(roots.Staged).ClearMerge().ClearRebase())
   174  				if err != nil {
   175  					return 1, err
   176  				}
   177  			}
   178  		}
   179  
   180  	}
   181  
   182  	return 0, nil
   183  }