github.com/hasnat/dolt/go@v0.0.0-20210628190320-9eb5d843fbb7/libraries/doltcore/dtestutils/testcommands/command.go (about)

     1  // Copyright 2020 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package testcommands
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"io"
    21  	"testing"
    22  	"time"
    23  
    24  	"github.com/stretchr/testify/assert"
    25  	"github.com/stretchr/testify/require"
    26  
    27  	"github.com/dolthub/dolt/go/cmd/dolt/commands/cnfcmds"
    28  	"github.com/dolthub/dolt/go/cmd/dolt/errhand"
    29  	"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
    30  	"github.com/dolthub/dolt/go/libraries/doltcore/env"
    31  	"github.com/dolthub/dolt/go/libraries/doltcore/env/actions"
    32  	"github.com/dolthub/dolt/go/libraries/doltcore/merge"
    33  	dsqle "github.com/dolthub/dolt/go/libraries/doltcore/sqle"
    34  )
    35  
    36  type Command interface {
    37  	CommandString() string
    38  	Exec(t *testing.T, dEnv *env.DoltEnv) error
    39  }
    40  
    41  type StageAll struct{}
    42  
    43  // CommandString describes the StageAll command for debugging purposes.
    44  func (a StageAll) CommandString() string { return "stage_all" }
    45  
    46  // Exec executes a StageAll command on a test dolt environment.
    47  func (a StageAll) Exec(t *testing.T, dEnv *env.DoltEnv) error {
    48  	return actions.StageAllTables(context.Background(), dEnv.DbData())
    49  }
    50  
    51  type CommitStaged struct {
    52  	Message string
    53  }
    54  
    55  // CommandString describes the CommitStaged command for debugging purposes.
    56  func (c CommitStaged) CommandString() string { return fmt.Sprintf("commit_staged: %s", c.Message) }
    57  
    58  // Exec executes a CommitStaged command on a test dolt environment.
    59  func (c CommitStaged) Exec(t *testing.T, dEnv *env.DoltEnv) error {
    60  	name, email, err := actions.GetNameAndEmail(dEnv.Config)
    61  
    62  	if err != nil {
    63  		return err
    64  	}
    65  
    66  	dbData := dEnv.DbData()
    67  
    68  	_, err = actions.CommitStaged(context.Background(), dbData, actions.CommitStagedProps{
    69  		Message:          c.Message,
    70  		Date:             time.Now(),
    71  		AllowEmpty:       false,
    72  		CheckForeignKeys: true,
    73  		Name:             name,
    74  		Email:            email,
    75  	})
    76  
    77  	return err
    78  }
    79  
    80  type CommitAll struct {
    81  	Message string
    82  }
    83  
    84  // CommandString describes the CommitAll command for debugging purposes.
    85  func (c CommitAll) CommandString() string { return fmt.Sprintf("commit: %s", c.Message) }
    86  
    87  // Exec executes a CommitAll command on a test dolt environment.
    88  func (c CommitAll) Exec(t *testing.T, dEnv *env.DoltEnv) error {
    89  	err := actions.StageAllTables(context.Background(), dEnv.DbData())
    90  	require.NoError(t, err)
    91  
    92  	name, email, err := actions.GetNameAndEmail(dEnv.Config)
    93  
    94  	if err != nil {
    95  		return err
    96  	}
    97  
    98  	dbData := dEnv.DbData()
    99  
   100  	_, err = actions.CommitStaged(context.Background(), dbData, actions.CommitStagedProps{
   101  		Message:          c.Message,
   102  		Date:             time.Now(),
   103  		AllowEmpty:       false,
   104  		CheckForeignKeys: true,
   105  		Name:             name,
   106  		Email:            email,
   107  	})
   108  
   109  	return err
   110  }
   111  
   112  type ResetHard struct{}
   113  
   114  // CommandString describes the ResetHard command for debugging purposes.
   115  func (r ResetHard) CommandString() string { return "reset_hard" }
   116  
   117  // NOTE: does not handle untracked tables
   118  func (r ResetHard) Exec(t *testing.T, dEnv *env.DoltEnv) error {
   119  	headRoot, err := dEnv.HeadRoot(context.Background())
   120  	if err != nil {
   121  		return err
   122  	}
   123  
   124  	err = dEnv.UpdateWorkingRoot(context.Background(), headRoot)
   125  	if err != nil {
   126  		return err
   127  	}
   128  
   129  	_, err = dEnv.UpdateStagedRoot(context.Background(), headRoot)
   130  	if err != nil {
   131  		return err
   132  	}
   133  
   134  	err = actions.SaveTrackedDocsFromWorking(context.Background(), dEnv)
   135  	return err
   136  }
   137  
   138  type Query struct {
   139  	Query string
   140  }
   141  
   142  // CommandString describes the Query command for debugging purposes.
   143  func (q Query) CommandString() string { return fmt.Sprintf("query %s", q.Query) }
   144  
   145  // Exec executes a Query command on a test dolt environment.
   146  func (q Query) Exec(t *testing.T, dEnv *env.DoltEnv) error {
   147  	root, err := dEnv.WorkingRoot(context.Background())
   148  	require.NoError(t, err)
   149  	sqlDb := dsqle.NewDatabase("dolt", dEnv.DbData())
   150  	engine, sqlCtx, err := dsqle.NewTestEngine(context.Background(), sqlDb, root)
   151  	require.NoError(t, err)
   152  
   153  	_, iter, err := engine.Query(sqlCtx, q.Query)
   154  	if err != nil {
   155  		return err
   156  	}
   157  
   158  	for {
   159  		_, err := iter.Next()
   160  		if err == io.EOF {
   161  			break
   162  		}
   163  		if err != nil {
   164  			return err
   165  		}
   166  	}
   167  
   168  	err = iter.Close(sqlCtx)
   169  	if err != nil {
   170  		return err
   171  	}
   172  
   173  	newRoot, err := sqlDb.GetRoot(sqlCtx)
   174  	require.NoError(t, err)
   175  
   176  	err = dEnv.UpdateWorkingRoot(context.Background(), newRoot)
   177  	return err
   178  }
   179  
   180  type Branch struct {
   181  	BranchName string
   182  }
   183  
   184  // CommandString describes the Branch command for debugging purposes.
   185  func (b Branch) CommandString() string { return fmt.Sprintf("branch: %s", b.BranchName) }
   186  
   187  // Exec executes a Branch command on a test dolt environment.
   188  func (b Branch) Exec(_ *testing.T, dEnv *env.DoltEnv) error {
   189  	cwb := dEnv.RepoState.Head.Ref.String()
   190  	return actions.CreateBranchWithStartPt(context.Background(), dEnv.DbData(), b.BranchName, cwb, false)
   191  }
   192  
   193  type Checkout struct {
   194  	BranchName string
   195  }
   196  
   197  // CommandString describes the Checkout command for debugging purposes.
   198  func (c Checkout) CommandString() string { return fmt.Sprintf("checkout: %s", c.BranchName) }
   199  
   200  // Exec executes a Checkout command on a test dolt environment.
   201  func (c Checkout) Exec(_ *testing.T, dEnv *env.DoltEnv) error {
   202  	return actions.CheckoutBranch(context.Background(), dEnv, c.BranchName)
   203  }
   204  
   205  type Merge struct {
   206  	BranchName string
   207  }
   208  
   209  // CommandString describes the Merge command for debugging purposes.
   210  func (m Merge) CommandString() string { return fmt.Sprintf("merge: %s", m.BranchName) }
   211  
   212  // Exec executes a Merge command on a test dolt environment.
   213  func (m Merge) Exec(t *testing.T, dEnv *env.DoltEnv) error {
   214  	// Adapted from commands/merge.go:Exec()
   215  	dref, err := dEnv.FindRef(context.Background(), m.BranchName)
   216  	assert.NoError(t, err)
   217  
   218  	cm1 := resolveCommit(t, "HEAD", dEnv)
   219  	cm2 := resolveCommit(t, dref.String(), dEnv)
   220  
   221  	h1, err := cm1.HashOf()
   222  	assert.NoError(t, err)
   223  
   224  	h2, err := cm2.HashOf()
   225  	assert.NoError(t, err)
   226  	assert.NotEqual(t, h1, h2)
   227  
   228  	tblNames, _, err := env.MergeWouldStompChanges(context.Background(), cm2, dEnv.DbData())
   229  	if err != nil {
   230  		return err
   231  	}
   232  	if len(tblNames) != 0 {
   233  		return errhand.BuildDError("error: failed to determine mergability.").AddCause(err).Build()
   234  	}
   235  
   236  	if ok, err := cm1.CanFastForwardTo(context.Background(), cm2); ok {
   237  		if err != nil {
   238  			return err
   239  		}
   240  
   241  		rv, err := cm2.GetRootValue()
   242  		assert.NoError(t, err)
   243  
   244  		h, err := dEnv.DoltDB.WriteRootValue(context.Background(), rv)
   245  		assert.NoError(t, err)
   246  
   247  		err = dEnv.DoltDB.FastForward(context.Background(), dEnv.RepoState.CWBHeadRef(), cm2)
   248  		if err != nil {
   249  			return err
   250  		}
   251  
   252  		dEnv.RepoState.Working = h.String()
   253  		dEnv.RepoState.Staged = h.String()
   254  		err = dEnv.RepoState.Save(dEnv.FS)
   255  		assert.NoError(t, err)
   256  
   257  		err = actions.SaveTrackedDocsFromWorking(context.Background(), dEnv)
   258  		assert.NoError(t, err)
   259  
   260  	} else {
   261  		mergedRoot, tblToStats, err := merge.MergeCommits(context.Background(), cm1, cm2)
   262  		require.NoError(t, err)
   263  		for _, stats := range tblToStats {
   264  			require.True(t, stats.Conflicts == 0)
   265  		}
   266  
   267  		h2, err := cm2.HashOf()
   268  		require.NoError(t, err)
   269  
   270  		err = dEnv.RepoState.StartMerge(h2.String(), dEnv.FS)
   271  		if err != nil {
   272  			return err
   273  		}
   274  
   275  		err = dEnv.UpdateWorkingRoot(context.Background(), mergedRoot)
   276  		if err != nil {
   277  			return err
   278  		}
   279  
   280  		err = actions.SaveTrackedDocsFromWorking(context.Background(), dEnv)
   281  		if err != nil {
   282  			return err
   283  		}
   284  
   285  		_, err = dEnv.UpdateStagedRoot(context.Background(), mergedRoot)
   286  		if err != nil {
   287  			return err
   288  		}
   289  	}
   290  	return nil
   291  }
   292  
   293  func resolveCommit(t *testing.T, cSpecStr string, dEnv *env.DoltEnv) *doltdb.Commit {
   294  	cs, err := doltdb.NewCommitSpec(cSpecStr)
   295  	require.NoError(t, err)
   296  	cm, err := dEnv.DoltDB.Resolve(context.TODO(), cs, dEnv.RepoState.CWBHeadRef())
   297  	require.NoError(t, err)
   298  	return cm
   299  }
   300  
   301  type ConflictsCat struct {
   302  	TableName string
   303  }
   304  
   305  // CommandString describes the ConflictsCat command for debugging purposes.
   306  func (c ConflictsCat) CommandString() string { return fmt.Sprintf("conflicts_cat: %s", c.TableName) }
   307  
   308  // Exec executes a ConflictsCat command on a test dolt environment.
   309  func (c ConflictsCat) Exec(t *testing.T, dEnv *env.DoltEnv) error {
   310  	out := cnfcmds.CatCmd{}.Exec(context.Background(), "dolt conflicts cat", []string{c.TableName}, dEnv)
   311  	require.Equal(t, 0, out)
   312  	return nil
   313  }