github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/sqle/dprocedures/dolt_pull.go (about)

     1  // Copyright 2022 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package dprocedures
    16  
    17  import (
    18  	"context"
    19  	"errors"
    20  	"fmt"
    21  	"sync"
    22  
    23  	"github.com/dolthub/go-mysql-server/sql"
    24  	gmstypes "github.com/dolthub/go-mysql-server/sql/types"
    25  
    26  	"github.com/dolthub/dolt/go/cmd/dolt/cli"
    27  	"github.com/dolthub/dolt/go/libraries/doltcore/branch_control"
    28  	"github.com/dolthub/dolt/go/libraries/doltcore/dbfactory"
    29  	"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
    30  	"github.com/dolthub/dolt/go/libraries/doltcore/env"
    31  	"github.com/dolthub/dolt/go/libraries/doltcore/env/actions"
    32  	"github.com/dolthub/dolt/go/libraries/doltcore/ref"
    33  	"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
    34  	"github.com/dolthub/dolt/go/store/datas/pull"
    35  )
    36  
    37  // PullProcFFIndex is the index of the fast_forward column in the row returned by dolt_pull(). Callers need it to
    38  // print results. If the schema of the result changes, this will need to be updated.
    39  const PullProcFFIndex = 0
    40  
    41  var doltPullSchema = []*sql.Column{
    42  	{
    43  		Name:     "fast_forward",
    44  		Type:     gmstypes.Int64,
    45  		Nullable: false,
    46  	},
    47  	{
    48  		Name:     "conflicts",
    49  		Type:     gmstypes.Int64,
    50  		Nullable: false,
    51  	},
    52  	{
    53  		Name:     "message",
    54  		Type:     gmstypes.LongText,
    55  		Nullable: true,
    56  	},
    57  }
    58  
    59  // doltPull is the stored procedure version for the CLI command `dolt pull`.
    60  func doltPull(ctx *sql.Context, args ...string) (sql.RowIter, error) {
    61  	conflicts, ff, msg, err := doDoltPull(ctx, args)
    62  	if err != nil {
    63  		return nil, err
    64  	}
    65  
    66  	if msg == "" {
    67  		return rowToIter(int64(ff), int64(conflicts), nil), nil
    68  	}
    69  	return rowToIter(int64(ff), int64(conflicts), msg), nil
    70  }
    71  
    72  // doDoltPull returns conflicts, fast_forward statuses
    73  func doDoltPull(ctx *sql.Context, args []string) (int, int, string, error) {
    74  	dbName := ctx.GetCurrentDatabase()
    75  
    76  	if len(dbName) == 0 {
    77  		return noConflictsOrViolations, threeWayMerge, "", fmt.Errorf("empty database name.")
    78  	}
    79  	if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
    80  		return noConflictsOrViolations, threeWayMerge, "", err
    81  	}
    82  
    83  	sess := dsess.DSessFromSess(ctx.Session)
    84  	dbData, ok := sess.GetDbData(ctx, dbName)
    85  	if !ok {
    86  		return noConflictsOrViolations, threeWayMerge, "", sql.ErrDatabaseNotFound.New(dbName)
    87  	}
    88  
    89  	apr, err := cli.CreatePullArgParser().Parse(args)
    90  	if err != nil {
    91  		return noConflictsOrViolations, threeWayMerge, "", err
    92  	}
    93  
    94  	if apr.NArg() > 2 {
    95  		return noConflictsOrViolations, threeWayMerge, "", actions.ErrInvalidPullArgs
    96  	}
    97  
    98  	var remoteName, remoteRefName string
    99  	if apr.NArg() == 1 {
   100  		remoteName = apr.Arg(0)
   101  	} else if apr.NArg() == 2 {
   102  		remoteName = apr.Arg(0)
   103  		remoteRefName = apr.Arg(1)
   104  	}
   105  
   106  	remoteOnly := apr.NArg() == 1
   107  	pullSpec, err := env.NewPullSpec(
   108  		ctx,
   109  		dbData.Rsr,
   110  		remoteName,
   111  		remoteRefName,
   112  		remoteOnly,
   113  		env.WithSquash(apr.Contains(cli.SquashParam)),
   114  		env.WithNoFF(apr.Contains(cli.NoFFParam)),
   115  		env.WithNoCommit(apr.Contains(cli.NoCommitFlag)),
   116  		env.WithNoEdit(apr.Contains(cli.NoEditFlag)),
   117  		env.WithForce(apr.Contains(cli.ForceFlag)),
   118  	)
   119  	if err != nil {
   120  		return noConflictsOrViolations, threeWayMerge, "", err
   121  	}
   122  
   123  	if user, hasUser := apr.GetValue(cli.UserFlag); hasUser {
   124  		pullSpec.Remote = pullSpec.Remote.WithParams(map[string]string{
   125  			dbfactory.GRPCUsernameAuthParam: user,
   126  		})
   127  	}
   128  
   129  	srcDB, err := sess.Provider().GetRemoteDB(ctx, dbData.Ddb.ValueReadWriter().Format(), pullSpec.Remote, false)
   130  	if err != nil {
   131  		return noConflictsOrViolations, threeWayMerge, "", fmt.Errorf("failed to get remote db; %w", err)
   132  	}
   133  
   134  	ws, err := sess.WorkingSet(ctx, dbName)
   135  	if err != nil {
   136  		return noConflictsOrViolations, threeWayMerge, "", err
   137  	}
   138  
   139  	// Fetch all references
   140  	branchRefs, err := srcDB.GetHeadRefs(ctx)
   141  	if err != nil {
   142  		return noConflictsOrViolations, threeWayMerge, "", fmt.Errorf("%w: %s", env.ErrFailedToReadDb, err.Error())
   143  	}
   144  
   145  	_, hasBranch, err := srcDB.HasBranch(ctx, pullSpec.Branch.GetPath())
   146  	if err != nil {
   147  		return noConflictsOrViolations, threeWayMerge, "", err
   148  	}
   149  	if !hasBranch {
   150  		return noConflictsOrViolations, threeWayMerge, "",
   151  			fmt.Errorf("branch %q not found on remote", pullSpec.Branch.GetPath())
   152  	}
   153  
   154  	mode := ref.UpdateMode{Force: true, Prune: false}
   155  	err = actions.FetchRefSpecs(ctx, dbData, srcDB, pullSpec.RefSpecs, &pullSpec.Remote, mode, runProgFuncs, stopProgFuncs)
   156  	if err != nil {
   157  		return noConflictsOrViolations, threeWayMerge, "", fmt.Errorf("fetch failed: %w", err)
   158  	}
   159  
   160  	var conflicts int
   161  	var fastForward int
   162  	var message string
   163  	for _, refSpec := range pullSpec.RefSpecs {
   164  		rsSeen := false // track invalid refSpecs
   165  		for _, branchRef := range branchRefs {
   166  			remoteTrackRef := refSpec.DestRef(branchRef)
   167  
   168  			if remoteTrackRef == nil {
   169  				continue
   170  			}
   171  
   172  			if branchRef != pullSpec.Branch {
   173  				continue
   174  			}
   175  
   176  			rsSeen = true
   177  
   178  			headRef, err := dbData.Rsr.CWBHeadRef()
   179  			if err != nil {
   180  				return noConflictsOrViolations, threeWayMerge, "", err
   181  			}
   182  
   183  			msg := fmt.Sprintf("Merge branch '%s' of %s into %s", pullSpec.Branch.GetPath(), pullSpec.Remote.Url, headRef.GetPath())
   184  
   185  			roots, ok := sess.GetRoots(ctx, dbName)
   186  			if !ok {
   187  				return noConflictsOrViolations, threeWayMerge, "", sql.ErrDatabaseNotFound.New(dbName)
   188  			}
   189  
   190  			mergeSpec, err := createMergeSpec(ctx, sess, dbName, apr, remoteTrackRef.String())
   191  			if err != nil {
   192  				return noConflictsOrViolations, threeWayMerge, "", err
   193  			}
   194  
   195  			uncommittedChanges, _, _, err := actions.RootHasUncommittedChanges(roots)
   196  			if err != nil {
   197  				return noConflictsOrViolations, threeWayMerge, "", err
   198  			}
   199  			if uncommittedChanges {
   200  				return noConflictsOrViolations, threeWayMerge, "", ErrUncommittedChanges.New()
   201  			}
   202  
   203  			ws, _, conflicts, fastForward, message, err = performMerge(ctx, sess, ws, dbName, mergeSpec, apr.Contains(cli.NoCommitFlag), msg)
   204  			if err != nil && !errors.Is(doltdb.ErrUpToDate, err) {
   205  				return conflicts, fastForward, "", err
   206  			}
   207  
   208  			err = sess.SetWorkingSet(ctx, dbName, ws)
   209  			if err != nil {
   210  				return conflicts, fastForward, "", err
   211  			}
   212  		}
   213  		if !rsSeen {
   214  			return noConflictsOrViolations, threeWayMerge, "", fmt.Errorf("%w: '%s'", ref.ErrInvalidRefSpec, refSpec.GetRemRefToLocal())
   215  		}
   216  	}
   217  
   218  	tmpDir, err := dbData.Rsw.TempTableFilesDir()
   219  	if err != nil {
   220  		return noConflictsOrViolations, threeWayMerge, "", err
   221  	}
   222  	err = actions.FetchFollowTags(ctx, tmpDir, srcDB, dbData.Ddb, runProgFuncs, stopProgFuncs)
   223  	if err != nil {
   224  		return conflicts, fastForward, "", err
   225  	}
   226  
   227  	return conflicts, fastForward, message, nil
   228  }
   229  
   230  // TODO: remove this as it does not do anything useful
   231  func pullerProgFunc(ctx context.Context, statsCh <-chan pull.Stats) {
   232  	for {
   233  		select {
   234  		case <-ctx.Done():
   235  			return
   236  		case <-statsCh:
   237  		}
   238  	}
   239  }
   240  
   241  // TODO: remove this as it does not do anything useful
   242  func runProgFuncs(ctx context.Context) (*sync.WaitGroup, chan pull.Stats) {
   243  	statsCh := make(chan pull.Stats)
   244  	wg := &sync.WaitGroup{}
   245  
   246  	wg.Add(1)
   247  	go func() {
   248  		defer wg.Done()
   249  		pullerProgFunc(ctx, statsCh)
   250  	}()
   251  
   252  	return wg, statsCh
   253  }
   254  
   255  // stopProgFuncs calls cancel, closes statsCh, and then waits on wg, in that order. TODO: remove this as it does not do anything useful
   256  func stopProgFuncs(cancel context.CancelFunc, wg *sync.WaitGroup, statsCh chan pull.Stats) {
   257  	cancel() // presumably cancels the ctx that was given to runProgFuncs; verify against the caller
   258  	close(statsCh)
   259  	wg.Wait() // blocks until every goroutine registered on wg has called Done
   260  }
   260  }