github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/sqle/dprocedures/dolt_pull.go (about) 1 // Copyright 2022 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package dprocedures 16 17 import ( 18 "context" 19 "errors" 20 "fmt" 21 "sync" 22 23 "github.com/dolthub/go-mysql-server/sql" 24 gmstypes "github.com/dolthub/go-mysql-server/sql/types" 25 26 "github.com/dolthub/dolt/go/cmd/dolt/cli" 27 "github.com/dolthub/dolt/go/libraries/doltcore/branch_control" 28 "github.com/dolthub/dolt/go/libraries/doltcore/dbfactory" 29 "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" 30 "github.com/dolthub/dolt/go/libraries/doltcore/env" 31 "github.com/dolthub/dolt/go/libraries/doltcore/env/actions" 32 "github.com/dolthub/dolt/go/libraries/doltcore/ref" 33 "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess" 34 "github.com/dolthub/dolt/go/store/datas/pull" 35 ) 36 37 // For callers of dolt_pull(), the index of the FastForward column is needed to print results. If the schema of 38 // the result changes, this will need to be updated. 39 const PullProcFFIndex = 0 40 41 var doltPullSchema = []*sql.Column{ 42 { 43 Name: "fast_forward", 44 Type: gmstypes.Int64, 45 Nullable: false, 46 }, 47 { 48 Name: "conflicts", 49 Type: gmstypes.Int64, 50 Nullable: false, 51 }, 52 { 53 Name: "message", 54 Type: gmstypes.LongText, 55 Nullable: true, 56 }, 57 } 58 59 // doltPull is the stored procedure version for the CLI command `dolt pull`. 60 func doltPull(ctx *sql.Context, args ...string) (sql.RowIter, error) { 61 conflicts, ff, msg, err := doDoltPull(ctx, args) 62 if err != nil { 63 return nil, err 64 } 65 66 if msg == "" { 67 return rowToIter(int64(ff), int64(conflicts), nil), nil 68 } 69 return rowToIter(int64(ff), int64(conflicts), msg), nil 70 } 71 72 // doDoltPull returns conflicts, fast_forward statuses 73 func doDoltPull(ctx *sql.Context, args []string) (int, int, string, error) { 74 dbName := ctx.GetCurrentDatabase() 75 76 if len(dbName) == 0 { 77 return noConflictsOrViolations, threeWayMerge, "", fmt.Errorf("empty database name.") 78 } 79 if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil { 80 return noConflictsOrViolations, threeWayMerge, "", err 81 } 82 83 sess := dsess.DSessFromSess(ctx.Session) 84 dbData, ok := sess.GetDbData(ctx, dbName) 85 if !ok { 86 return noConflictsOrViolations, threeWayMerge, "", sql.ErrDatabaseNotFound.New(dbName) 87 } 88 89 apr, err := cli.CreatePullArgParser().Parse(args) 90 if err != nil { 91 return noConflictsOrViolations, threeWayMerge, "", err 92 } 93 94 if apr.NArg() > 2 { 95 return noConflictsOrViolations, threeWayMerge, "", actions.ErrInvalidPullArgs 96 } 97 98 var remoteName, remoteRefName string 99 if apr.NArg() == 1 { 100 remoteName = apr.Arg(0) 101 } else if apr.NArg() == 2 { 102 remoteName = apr.Arg(0) 103 remoteRefName = apr.Arg(1) 104 } 105 106 remoteOnly := apr.NArg() == 1 107 pullSpec, err := env.NewPullSpec( 108 ctx, 109 dbData.Rsr, 110 remoteName, 111 remoteRefName, 112 remoteOnly, 113 env.WithSquash(apr.Contains(cli.SquashParam)), 114 env.WithNoFF(apr.Contains(cli.NoFFParam)), 115 env.WithNoCommit(apr.Contains(cli.NoCommitFlag)), 116 env.WithNoEdit(apr.Contains(cli.NoEditFlag)), 117 env.WithForce(apr.Contains(cli.ForceFlag)), 118 ) 119 if err != nil { 120 return noConflictsOrViolations, threeWayMerge, "", err 121 } 122 123 if user, hasUser := apr.GetValue(cli.UserFlag); hasUser { 124 pullSpec.Remote = pullSpec.Remote.WithParams(map[string]string{ 125 dbfactory.GRPCUsernameAuthParam: user, 126 }) 127 } 128 129 srcDB, err := sess.Provider().GetRemoteDB(ctx, dbData.Ddb.ValueReadWriter().Format(), pullSpec.Remote, false) 130 if err != nil { 131 return noConflictsOrViolations, threeWayMerge, "", fmt.Errorf("failed to get remote db; %w", err) 132 } 133 134 ws, err := sess.WorkingSet(ctx, dbName) 135 if err != nil { 136 return noConflictsOrViolations, threeWayMerge, "", err 137 } 138 139 // Fetch all references 140 branchRefs, err := srcDB.GetHeadRefs(ctx) 141 if err != nil { 142 return noConflictsOrViolations, threeWayMerge, "", fmt.Errorf("%w: %s", env.ErrFailedToReadDb, err.Error()) 143 } 144 145 _, hasBranch, err := srcDB.HasBranch(ctx, pullSpec.Branch.GetPath()) 146 if err != nil { 147 return noConflictsOrViolations, threeWayMerge, "", err 148 } 149 if !hasBranch { 150 return noConflictsOrViolations, threeWayMerge, "", 151 fmt.Errorf("branch %q not found on remote", pullSpec.Branch.GetPath()) 152 } 153 154 mode := ref.UpdateMode{Force: true, Prune: false} 155 err = actions.FetchRefSpecs(ctx, dbData, srcDB, pullSpec.RefSpecs, &pullSpec.Remote, mode, runProgFuncs, stopProgFuncs) 156 if err != nil { 157 return noConflictsOrViolations, threeWayMerge, "", fmt.Errorf("fetch failed: %w", err) 158 } 159 160 var conflicts int 161 var fastForward int 162 var message string 163 for _, refSpec := range pullSpec.RefSpecs { 164 rsSeen := false // track invalid refSpecs 165 for _, branchRef := range branchRefs { 166 remoteTrackRef := refSpec.DestRef(branchRef) 167 168 if remoteTrackRef == nil { 169 continue 170 } 171 172 if branchRef != pullSpec.Branch { 173 continue 174 } 175 176 rsSeen = true 177 178 headRef, err := dbData.Rsr.CWBHeadRef() 179 if err != nil { 180 return noConflictsOrViolations, threeWayMerge, "", err 181 } 182 183 msg := fmt.Sprintf("Merge branch '%s' of %s into %s", pullSpec.Branch.GetPath(), pullSpec.Remote.Url, headRef.GetPath()) 184 185 roots, ok := sess.GetRoots(ctx, dbName) 186 if !ok { 187 return noConflictsOrViolations, threeWayMerge, "", sql.ErrDatabaseNotFound.New(dbName) 188 } 189 190 mergeSpec, err := createMergeSpec(ctx, sess, dbName, apr, remoteTrackRef.String()) 191 if err != nil { 192 return noConflictsOrViolations, threeWayMerge, "", err 193 } 194 195 uncommittedChanges, _, _, err := actions.RootHasUncommittedChanges(roots) 196 if err != nil { 197 return noConflictsOrViolations, threeWayMerge, "", err 198 } 199 if uncommittedChanges { 200 return noConflictsOrViolations, threeWayMerge, "", ErrUncommittedChanges.New() 201 } 202 203 ws, _, conflicts, fastForward, message, err = performMerge(ctx, sess, ws, dbName, mergeSpec, apr.Contains(cli.NoCommitFlag), msg) 204 if err != nil && !errors.Is(doltdb.ErrUpToDate, err) { 205 return conflicts, fastForward, "", err 206 } 207 208 err = sess.SetWorkingSet(ctx, dbName, ws) 209 if err != nil { 210 return conflicts, fastForward, "", err 211 } 212 } 213 if !rsSeen { 214 return noConflictsOrViolations, threeWayMerge, "", fmt.Errorf("%w: '%s'", ref.ErrInvalidRefSpec, refSpec.GetRemRefToLocal()) 215 } 216 } 217 218 tmpDir, err := dbData.Rsw.TempTableFilesDir() 219 if err != nil { 220 return noConflictsOrViolations, threeWayMerge, "", err 221 } 222 err = actions.FetchFollowTags(ctx, tmpDir, srcDB, dbData.Ddb, runProgFuncs, stopProgFuncs) 223 if err != nil { 224 return conflicts, fastForward, "", err 225 } 226 227 return conflicts, fastForward, message, nil 228 } 229 230 // TODO: remove this as it does not do anything useful 231 func pullerProgFunc(ctx context.Context, statsCh <-chan pull.Stats) { 232 for { 233 select { 234 case <-ctx.Done(): 235 return 236 case <-statsCh: 237 } 238 } 239 } 240 241 // TODO: remove this as it does not do anything useful 242 func runProgFuncs(ctx context.Context) (*sync.WaitGroup, chan pull.Stats) { 243 statsCh := make(chan pull.Stats) 244 wg := &sync.WaitGroup{} 245 246 wg.Add(1) 247 go func() { 248 defer wg.Done() 249 pullerProgFunc(ctx, statsCh) 250 }() 251 252 return wg, statsCh 253 } 254 255 // TODO: remove this as it does not do anything useful 256 func stopProgFuncs(cancel context.CancelFunc, wg *sync.WaitGroup, statsCh chan pull.Stats) { 257 cancel() 258 close(statsCh) 259 wg.Wait() 260 }